diff --git a/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java b/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java index ae6de6cd664..1a1cef42d90 100644 --- a/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java +++ b/auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,13 +32,13 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.HttpMethod; import org.springframework.http.client.ClientHttpResponse; import org.springframework.lang.NonNull; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.CollectionUtils; import org.springframework.util.StreamUtils; import org.springframework.web.client.ResponseErrorHandler; @@ -60,21 +61,25 @@ public class SpringAiRetryAutoConfiguration { @Bean @ConditionalOnMissingBean public RetryTemplate retryTemplate(SpringAiRetryProperties properties) { - return RetryTemplate.builder() + RetryPolicy retryPolicy = RetryPolicy.builder() .maxAttempts(properties.getMaxAttempts()) - .retryOn(TransientAiException.class) - .exponentialBackoff(properties.getBackoff().getInitialInterval(), properties.getBackoff().getMultiplier(), - properties.getBackoff().getMaxInterval()) - .withListener(new RetryListener() { - - @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - logger.warn("Retry error. Retry count: {}, Exception: {}", context.getRetryCount(), - throwable.getMessage(), throwable); - } - }) + .includes(TransientAiException.class) + .delay(properties.getBackoff().getInitialInterval()) + .multiplier(properties.getBackoff().getMultiplier()) + .maxDelay(properties.getBackoff().getMaxInterval()) .build(); + + RetryTemplate retryTemplate = new RetryTemplate(retryPolicy); + retryTemplate.setRetryListener(new RetryListener() { + private final AtomicInteger retryCount = new AtomicInteger(0); + + @Override + public void onRetryFailure(RetryPolicy policy, Retryable retryable, Throwable throwable) { + int currentRetries = this.retryCount.incrementAndGet(); + logger.warn("Retry error. Retry count:{}", currentRetries, throwable); + } + }); + return retryTemplate; } @Bean diff --git a/auto-configurations/common/spring-ai-autoconfigure-retry/src/test/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfigurationIT.java b/auto-configurations/common/spring-ai-autoconfigure-retry/src/test/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfigurationIT.java index 744b139c138..dd4004dbb6b 100644 --- a/auto-configurations/common/spring-ai-autoconfigure-retry/src/test/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfigurationIT.java +++ b/auto-configurations/common/spring-ai-autoconfigure-retry/src/test/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfigurationIT.java @@ -21,7 +21,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import static org.assertj.core.api.Assertions.assertThat; diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/pom.xml b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/pom.xml index 49c2597d70b..da7cb85c488 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/pom.xml +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/pom.xml @@ -56,11 +56,6 @@ spring-boot-autoconfigure - - org.springframework - spring-core - - org.springframework.boot spring-boot-configuration-processor @@ -86,6 +81,11 @@ ${azure-identity.version} + + org.slf4j + jcl-over-slf4j + + org.springframework.boot diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/pom.xml index b901c50a174..0eee243a13f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/pom.xml @@ -83,12 +83,6 @@ true - - org.springframework.boot - spring-boot-starter-webflux - true - - org.springframework.ai diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java index 7710156343b..0ec1bf2636f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/main/java/org/springframework/ai/model/anthropic/autoconfigure/AnthropicChatAutoConfiguration.java @@ -38,7 +38,7 @@ import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java index afacd4fe45a..eb84d57ee65 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-deepseek/src/main/java/org/springframework/ai/model/deepseek/autoconfigure/DeepSeekChatAutoConfiguration.java @@ -38,7 +38,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/pom.xml index f21501672b1..af77303b055 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/pom.xml @@ -77,12 +77,6 @@ true - - org.springframework.boot - spring-boot-starter-webflux - true - - org.springframework.ai diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/src/main/java/org/springframework/ai/model/elevenlabs/autoconfigure/ElevenLabsAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/src/main/java/org/springframework/ai/model/elevenlabs/autoconfigure/ElevenLabsAutoConfiguration.java index 098ed50c590..ad0c20eecf5 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/src/main/java/org/springframework/ai/model/elevenlabs/autoconfigure/ElevenLabsAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-elevenlabs/src/main/java/org/springframework/ai/model/elevenlabs/autoconfigure/ElevenLabsAutoConfiguration.java @@ -28,7 +28,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/chat/GoogleGenAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/chat/GoogleGenAiChatAutoConfiguration.java index e8c9b8ea2d1..664b1e3b9be 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/chat/GoogleGenAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/chat/GoogleGenAiChatAutoConfiguration.java @@ -42,7 +42,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Conditional; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/embedding/GoogleGenAiTextEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/embedding/GoogleGenAiTextEmbeddingAutoConfiguration.java index 6fd62c663e4..d3825836171 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/embedding/GoogleGenAiTextEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-google-genai/src/main/java/org/springframework/ai/model/google/genai/autoconfigure/embedding/GoogleGenAiTextEmbeddingAutoConfiguration.java @@ -31,7 +31,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * Auto-configuration for Google GenAI Text Embedding. diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxChatAutoConfiguration.java index f8abc3cc6ac..7cebce533fd 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxChatAutoConfiguration.java @@ -36,7 +36,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxEmbeddingAutoConfiguration.java index d6cd0a94213..36e3a82cb00 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-minimax/src/main/java/org/springframework/ai/model/minimax/autoconfigure/MiniMaxEmbeddingAutoConfiguration.java @@ -32,7 +32,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java index 3d83d3e4eac..8ae982b503f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiChatAutoConfiguration.java @@ -37,7 +37,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiEmbeddingAutoConfiguration.java index b08d1955037..efa30bedd9d 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiEmbeddingAutoConfiguration.java @@ -32,7 +32,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiModerationAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiModerationAutoConfiguration.java index 31cd04f7fc3..cb9d89db164 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiModerationAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/main/java/org/springframework/ai/model/mistralai/autoconfigure/MistralAiModerationAutoConfiguration.java @@ -31,7 +31,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java index 34b9ad58346..67bdae15714 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfiguration.java @@ -36,7 +36,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Ollama Chat model. diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAIAutoConfigurationUtil.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAIAutoConfigurationUtil.java index 7eff1898fa4..2447478f4fd 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAIAutoConfigurationUtil.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAIAutoConfigurationUtil.java @@ -16,14 +16,9 @@ package org.springframework.ai.model.openai.autoconfigure; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - +import org.springframework.http.HttpHeaders; import org.springframework.lang.NonNull; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; public final class OpenAIAutoConfigurationUtil { @@ -44,12 +39,12 @@ private OpenAIAutoConfigurationUtil() { String organizationId = StringUtils.hasText(modelProperties.getOrganizationId()) ? modelProperties.getOrganizationId() : commonProperties.getOrganizationId(); - Map> connectionHeaders = new HashMap<>(); + HttpHeaders connectionHeaders = new HttpHeaders(); if (StringUtils.hasText(projectId)) { - connectionHeaders.put("OpenAI-Project", List.of(projectId)); + connectionHeaders.add("OpenAI-Project", projectId); } if (StringUtils.hasText(organizationId)) { - connectionHeaders.put("OpenAI-Organization", List.of(organizationId)); + connectionHeaders.add("OpenAI-Organization", organizationId); } Assert.hasText(baseUrl, @@ -59,10 +54,10 @@ private OpenAIAutoConfigurationUtil() { "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." + modelType + ".api-key property."); - return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders)); + return new ResolvedConnectionProperties(baseUrl, apiKey, connectionHeaders); } - public record ResolvedConnectionProperties(String baseUrl, String apiKey, MultiValueMap headers) { + public record ResolvedConnectionProperties(String baseUrl, String apiKey, HttpHeaders headers) { } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioSpeechAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioSpeechAutoConfiguration.java index 03f16b890cc..a4623072f90 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioSpeechAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioSpeechAutoConfiguration.java @@ -32,7 +32,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioTranscriptionAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioTranscriptionAutoConfiguration.java index efdf9d3d756..acba2f5d927 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioTranscriptionAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiAudioTranscriptionAutoConfiguration.java @@ -32,7 +32,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java index b35a1e95999..f8f5f801a11 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiChatAutoConfiguration.java @@ -38,7 +38,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiEmbeddingAutoConfiguration.java index d655d1d9571..ac85dbdc248 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiEmbeddingAutoConfiguration.java @@ -34,7 +34,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiImageAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiImageAutoConfiguration.java index 34f0f656370..ba7ee1f8c11 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiImageAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiImageAutoConfiguration.java @@ -35,7 +35,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiModerationAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiModerationAutoConfiguration.java index 4df477f1f3f..f844f4d5ec8 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiModerationAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/main/java/org/springframework/ai/model/openai/autoconfigure/OpenAiModerationAutoConfiguration.java @@ -32,7 +32,7 @@ import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.boot.webclient.autoconfigure.WebClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/embedding/VertexAiTextEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/embedding/VertexAiTextEmbeddingAutoConfiguration.java index 56e253dc2d2..8c0370ec84d 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/embedding/VertexAiTextEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/embedding/VertexAiTextEmbeddingAutoConfiguration.java @@ -31,7 +31,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * Auto-configuration for Vertex AI Gemini Chat. diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java index edb7057c1e6..79b4865ca8e 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-vertex-ai/src/main/java/org/springframework/ai/model/vertexai/autoconfigure/gemini/VertexAiGeminiChatAutoConfiguration.java @@ -39,7 +39,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/pom.xml index ff3c59e3c97..92811dcf397 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/pom.xml @@ -90,12 +90,6 @@ true - - org.springframework.boot - spring-boot-starter-webflux - true - - org.springframework.ai diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java index 683e0683a5b..4023978c697 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java @@ -37,7 +37,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java index 1bbb866574f..94ab54dfda9 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java @@ -33,7 +33,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java index 61c66c1f588..d1e876f2afd 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java @@ -30,7 +30,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.restclient.autoconfigure.RestClientAutoConfiguration; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/main/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfiguration.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/main/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfiguration.java index 8d84e081b71..63e31484166 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/main/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfiguration.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/main/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfiguration.java @@ -16,8 +16,6 @@ package org.springframework.ai.vectorstore.couchbase.autoconfigure; -import java.util.Objects; - import com.couchbase.client.java.Cluster; import org.springframework.ai.embedding.EmbeddingModel; diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/test/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/test/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfigurationIT.java index d5bc6980b11..902f44353a5 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/test/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfigurationIT.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-couchbase/src/test/java/org/springframework/ai/vectorstore/couchbase/autoconfigure/CouchbaseSearchVectorStoreAutoConfigurationIT.java @@ -63,8 +63,8 @@ class CouchbaseSearchVectorStoreAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(CouchbaseAutoConfiguration.class, - CouchbaseSearchVectorStoreAutoConfiguration.class, - SpringAiRetryAutoConfiguration.class, OpenAiEmbeddingAutoConfiguration.class)) + CouchbaseSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, + OpenAiEmbeddingAutoConfiguration.class)) .withPropertyValues("spring.couchbase.connection-string=" + couchbaseContainer.getConnectionString(), "spring.couchbase.username=" + couchbaseContainer.getUsername(), "spring.couchbase.password=" + couchbaseContainer.getPassword(), @@ -112,8 +112,8 @@ public void addAndSearchWithFilters() { public void propertiesTest() { new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(CouchbaseAutoConfiguration.class, - CouchbaseSearchVectorStoreAutoConfiguration.class, - SpringAiRetryAutoConfiguration.class, OpenAiEmbeddingAutoConfiguration.class)) + CouchbaseSearchVectorStoreAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, + OpenAiEmbeddingAutoConfiguration.class)) .withPropertyValues("spring.couchbase.connection-string=" + couchbaseContainer.getConnectionString(), "spring.couchbase.username=" + couchbaseContainer.getUsername(), "spring.couchbase.password=" + couchbaseContainer.getPassword(), diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-redis/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-redis/pom.xml index f23ec23b061..6b57660413e 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-redis/pom.xml +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-redis/pom.xml @@ -52,6 +52,7 @@ org.springframework.boot spring-boot-starter-data-redis + true org.springframework.boot @@ -60,7 +61,7 @@ org.springframework.boot - spring-boot-autoconfigure-processor + spring-boot-data-redis true diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml index 7eed6839006..420d77b50fe 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml @@ -117,11 +117,6 @@ spring-boot-starter-test test - - org.springframework.boot - spring-boot-jdbc - test - org.testcontainers @@ -132,7 +127,6 @@ org.testcontainers testcontainers-oracle-free - 2.0.1 test diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 40010e11ad9..f7e0567fd03 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -78,11 +78,12 @@ import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.util.json.JsonParser; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** @@ -193,8 +194,19 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate.execute( - ctx -> this.anthropicApi.chatCompletionEntity(request, this.getAdditionalHttpHeaders(prompt))); + ResponseEntity completionEntity = null; + try { + completionEntity = this.retryTemplate.execute(() -> this.anthropicApi.chatCompletionEntity(request, + this.getAdditionalHttpHeaders(prompt))); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody(); AnthropicApi.Usage usage = completionResponse.usage(); @@ -523,14 +535,15 @@ else if (mimeType.contains("pdf")) { + ". Supported types are: images (image/*) and PDF documents (application/pdf)"); } - private MultiValueMap getAdditionalHttpHeaders(Prompt prompt) { + private HttpHeaders getAdditionalHttpHeaders(Prompt prompt) { Map headers = new HashMap<>(this.defaultOptions.getHttpHeaders()); if (prompt.getOptions() != null && prompt.getOptions() instanceof AnthropicChatOptions chatOptions) { headers.putAll(chatOptions.getHttpHeaders()); } - return CollectionUtils.toMultiValueMap( - headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue())))); + HttpHeaders httpHeaders = new HttpHeaders(); + headers.forEach(httpHeaders::add); + return httpHeaders; } Prompt buildRequestPrompt(Prompt prompt) { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 824cfef2b83..e18c4d38801 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -48,8 +48,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -169,7 +167,7 @@ public AnthropicApi(String completionsPath, RestClient restClient, WebClient web * status code and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionEntity(chatRequest, new HttpHeaders()); } /** @@ -180,7 +178,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletio * status code and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); @@ -190,7 +188,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletio return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -206,7 +204,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletio * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionStream(chatRequest, new HttpHeaders()); } /** @@ -217,7 +215,7 @@ public Flux chatCompletionStream(ChatCompletionRequest c * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); @@ -231,7 +229,7 @@ public Flux chatCompletionStream(ChatCompletionRequest c return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) // @formatter:off .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -270,7 +268,7 @@ public Flux chatCompletionStream(ChatCompletionRequest c } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (null == headers.getFirst(HEADER_X_API_KEY)) { + if (!headers.containsHeader(HEADER_X_API_KEY)) { String apiKeyValue = this.apiKey.getValue(); if (StringUtils.hasText(apiKeyValue)) { headers.add(HEADER_X_API_KEY, apiKeyValue); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java index 910f572d208..382dbe6f64f 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java @@ -40,7 +40,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; @@ -171,7 +171,7 @@ public AnthropicApi anthropicApi() { public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, TestObservationRegistry observationRegistry) { return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(), - ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry); + ToolCallingManager.builder().build(), new RetryTemplate(), observationRegistry); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiBuilderTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiBuilderTests.java index 1f42a35a67e..6b2a1caf8d1 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiBuilderTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiBuilderTests.java @@ -37,8 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -304,7 +302,7 @@ void dynamicApiKeyRestClientWithAdditionalApiKeyHeader() throws InterruptedExcep .temperature(0.8) .messages(List.of(chatCompletionMessage)) .build(); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + var additionalHeaders = new HttpHeaders(); additionalHeaders.add("x-api-key", "additional-key"); ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); @@ -403,7 +401,7 @@ void dynamicApiKeyWebClientWithAdditionalApiKey() throws InterruptedException { .messages(List.of(chatCompletionMessage)) .stream(true) .build(); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + var additionalHeaders = new HttpHeaders(); additionalHeaders.add("x-api-key", "additional-key"); api.chatCompletionStream(request, additionalHeaders).collectList().block(); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index fba44ffd4ce..c89a1c0c812 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -66,8 +66,9 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -165,8 +166,18 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.deepSeekApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = null; + try { + completionEntity = this.retryTemplate.execute(() -> this.deepSeekApi.chatCompletionEntity(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } var chatCompletion = completionEntity.getBody(); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java index cdb061a78da..13667f0e3d5 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java @@ -41,8 +41,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -80,7 +78,7 @@ public class DeepSeekApi { * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. */ - public DeepSeekApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, + public DeepSeekApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String completionsPath, String betaPrefixPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { @@ -150,7 +148,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionStream(chatRequest, new HttpHeaders()); } /** @@ -162,7 +160,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); @@ -929,7 +927,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private String completionsPath = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_COMPLETIONS_PATH; @@ -959,7 +957,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java index 35f24eaadeb..511ee734806 100644 --- a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekRetryTests.java @@ -34,11 +34,11 @@ import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -62,7 +62,7 @@ public class DeepSeekRetryTests { public void beforeEach() { RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - retryTemplate.registerListener(this.retryListener); + retryTemplate.setRetryListener(this.retryListener); this.chatModel = DeepSeekChatModel.builder() .deepSeekApi(this.deepSeekApi) @@ -88,7 +88,7 @@ public void deepSeekChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -116,7 +116,7 @@ public void deepSeekChatStreamTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -134,14 +134,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java index 00e8732dade..b2900e479cb 100644 --- a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelObservationIT.java @@ -40,7 +40,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -172,7 +172,7 @@ public DeepSeekApi deepSeekApi() { public DeepSeekChatModel deepSeekChatModel(DeepSeekApi deepSeekApi, TestObservationRegistry observationRegistry) { return new DeepSeekChatModel(deepSeekApi, DeepSeekChatOptions.builder().build(), - ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry); + ToolCallingManager.builder().build(), new RetryTemplate(), observationRegistry); } } diff --git a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java index 68ed07568a8..b8220e87065 100644 --- a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java +++ b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java @@ -29,7 +29,8 @@ import org.springframework.ai.audio.tts.TextToSpeechResponse; import org.springframework.ai.elevenlabs.api.ElevenLabsApi; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -73,15 +74,26 @@ public static Builder builder() { public TextToSpeechResponse call(TextToSpeechPrompt prompt) { RequestContext requestContext = prepareRequest(prompt); - byte[] audioData = this.retryTemplate.execute(context -> { - var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId, - requestContext.queryParameters); - if (response.getBody() == null) { - logger.warn("No speech response returned for request: {}", requestContext.request); - return new byte[0]; + byte[] audioData = null; + try { + audioData = this.retryTemplate.execute(() -> { + var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId, + requestContext.queryParameters); + if (response.getBody() == null) { + logger.warn("No speech response returned for request: {}", requestContext.request); + return new byte[0]; + } + return response.getBody(); + }); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); } - return response.getBody(); - }); + } return new TextToSpeechResponse(List.of(new Speech(audioData))); } @@ -90,9 +102,19 @@ public TextToSpeechResponse call(TextToSpeechPrompt prompt) { public Flux stream(TextToSpeechPrompt prompt) { RequestContext requestContext = prepareRequest(prompt); - return this.retryTemplate.execute(context -> this.elevenLabsApi - .textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters) - .map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody()))))); + try { + return this.retryTemplate.execute(() -> this.elevenLabsApi + .textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters) + .map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody()))))); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } } private RequestContext prepareRequest(TextToSpeechPrompt prompt) { diff --git a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java index 691de26b690..2f20727e17c 100644 --- a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java +++ b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsApi.java @@ -33,7 +33,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -62,9 +61,8 @@ public final class ElevenLabsApi { * @param webClientBuilder A builder for the Spring WebClient. * @param responseErrorHandler A custom error handler for API responses. */ - private ElevenLabsApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { + private ElevenLabsApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { Consumer jsonContentHeaders = h -> { if (!(apiKey instanceof NoopApiKey)) { @@ -341,7 +339,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -367,7 +365,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsVoicesApi.java b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsVoicesApi.java index 2cffa16ae29..766191a8a73 100644 --- a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsVoicesApi.java +++ b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/api/ElevenLabsVoicesApi.java @@ -32,8 +32,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -56,8 +54,8 @@ public class ElevenLabsVoicesApi { * @param restClientBuilder A builder for the Spring RestClient. * @param responseErrorHandler A custom error handler for API responses. */ - public ElevenLabsVoicesApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + public ElevenLabsVoicesApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { Consumer jsonContentHeaders = h -> { if (!(apiKey instanceof NoopApiKey)) { h.set("xi-api-key", apiKey.getValue()); @@ -407,7 +405,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -431,7 +429,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java index 46c87cd6862..049720082bf 100644 --- a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java +++ b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java @@ -46,7 +46,8 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -168,8 +169,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { } // Call the embedding API with retry - EmbedContentResponse embeddingResponse = this.retryTemplate - .execute(context -> this.genAiClient.models.embedContent(modelName, validTexts, config)); + EmbedContentResponse embeddingResponse = null; + try { + embeddingResponse = this.retryTemplate + .execute(() -> this.genAiClient.models.embedContent(modelName, validTexts, config)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } // Process the response // Note: We need to handle the case where some texts were filtered out diff --git a/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingRetryTests.java b/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingRetryTests.java index 4dc9fce14c5..6d2136aca1a 100644 --- a/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingRetryTests.java @@ -36,10 +36,10 @@ import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -75,7 +75,7 @@ public class GoogleGenAiTextEmbeddingRetryTests { public void setUp() throws Exception { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); // Create a mock Client and use reflection to set the models field this.mockGenAiClient = mock(Client.class); @@ -114,7 +114,7 @@ public void vertexAiEmbeddingTransientError() { assertThat(result).isNotNull(); assertThat(result.getResults()).hasSize(1); assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); verify(this.mockModels, times(3)).embedContent(anyString(), any(List.class), any(EmbedContentConfig.class)); @@ -143,14 +143,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/TestGoogleGenAiTextEmbeddingModel.java b/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/TestGoogleGenAiTextEmbeddingModel.java index 44a06031afc..9836f63ec30 100644 --- a/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/TestGoogleGenAiTextEmbeddingModel.java +++ b/models/spring-ai-google-genai-embedding/src/test/java/org/springframework/ai/google/genai/text/TestGoogleGenAiTextEmbeddingModel.java @@ -17,7 +17,7 @@ package org.springframework.ai.google.genai.text; import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * Test implementation of GoogleGenAiTextEmbeddingModel that uses a mock connection for diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 8e38008e859..70f5c1385dc 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -87,8 +87,9 @@ import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.beans.factory.DisposableBean; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.lang.NonNull; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -404,30 +405,41 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> this.retryTemplate.execute(context -> { - - var geminiRequest = createGeminiRequest(prompt); - - GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + .observe(() -> { + try { + return this.retryTemplate.execute(() -> { + + var geminiRequest = createGeminiRequest(prompt); + + GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + + List generations = generateContentResponse.candidates() + .orElse(List.of()) + .stream() + .map(this::responseCandidateToGeneration) + .flatMap(List::stream) + .toList(); + + var usage = generateContentResponse.usageMetadata(); + GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions(); + Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options) + : getDefaultUsage(null, options); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get())); + + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } - List generations = generateContentResponse.candidates() - .orElse(List.of()) - .stream() - .map(this::responseCandidateToGeneration) - .flatMap(List::stream) - .toList(); - - var usage = generateContentResponse.usageMetadata(); - GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions(); - Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options) - : getDefaultUsage(null, options); - Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, - toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get())); - - observationContext.setResponse(chatResponse); - return chatResponse; - })); + throw new RuntimeException(e); + } + }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelCachedContentTests.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelCachedContentTests.java index af4ea6679e0..f2340cc79d6 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelCachedContentTests.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelCachedContentTests.java @@ -36,7 +36,7 @@ import org.springframework.ai.google.genai.cache.GoogleGenAiCachedContent; import org.springframework.ai.google.genai.cache.GoogleGenAiCachedContentService; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelExtendedUsageTests.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelExtendedUsageTests.java index 27e3ccf24a5..1b75b991a3a 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelExtendedUsageTests.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelExtendedUsageTests.java @@ -41,7 +41,7 @@ import org.springframework.ai.google.genai.metadata.GoogleGenAiTrafficType; import org.springframework.ai.google.genai.metadata.GoogleGenAiUsage; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiRetryTests.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiRetryTests.java index 4170c992c64..901036959bc 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiRetryTests.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiRetryTests.java @@ -27,10 +27,10 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; /** * @author Mark Pollack @@ -55,7 +55,7 @@ public class GoogleGenAiRetryTests { public void setUp() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = new org.springframework.ai.google.genai.TestGoogleGenAiGeminiChatModel(this.genAiClient, GoogleGenAiChatOptions.builder() @@ -95,14 +95,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/TestGoogleGenAiGeminiChatModel.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/TestGoogleGenAiGeminiChatModel.java index 6c63133fd1f..2730b73e501 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/TestGoogleGenAiGeminiChatModel.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/TestGoogleGenAiGeminiChatModel.java @@ -20,7 +20,7 @@ import com.google.genai.types.GenerateContentResponse; import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * @author Mark Pollack diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 5c771b2f5db..fa72c428c86 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -68,8 +68,9 @@ import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -253,8 +254,18 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.miniMaxApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = null; + try { + completionEntity = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionEntity(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } var chatCompletion = completionEntity.getBody(); @@ -328,8 +339,18 @@ public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(requestPrompt, true); - Flux completionChunks = this.retryTemplate - .execute(ctx -> this.miniMaxApi.chatCompletionStream(request)); + Flux completionChunks = null; + try { + completionChunks = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionStream(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java index 5dd0077fed3..ab631ee1825 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java @@ -40,7 +40,8 @@ import org.springframework.ai.minimax.api.MiniMaxApiConstants; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -165,8 +166,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - MiniMaxApi.EmbeddingList apiEmbeddingResponse = this.retryTemplate - .execute(ctx -> this.miniMaxApi.embeddings(apiRequest).getBody()); + MiniMaxApi.EmbeddingList apiEmbeddingResponse = null; + try { + apiEmbeddingResponse = this.retryTemplate + .execute(() -> this.miniMaxApi.embeddings(apiRequest).getBody()); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java index 9f165e27c0d..32d0baccc95 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java @@ -45,11 +45,11 @@ import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -78,7 +78,7 @@ public class MiniMaxRetryTests { public void beforeEach() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = new MiniMaxChatModel(this.miniMaxApi, MiniMaxChatOptions.builder().build(), ToolCallingManager.builder().build(), this.retryTemplate); @@ -103,7 +103,7 @@ public void miniMaxChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -131,7 +131,7 @@ public void miniMaxChatStreamTransientError() { assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -158,7 +158,7 @@ public void miniMaxEmbeddingTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -171,21 +171,22 @@ public void miniMaxEmbeddingNonTransientError() { .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options))); } - private class TestRetryListener implements RetryListener { + private static class TestRetryListener implements RetryListener { int onErrorRetryCount = 0; int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java index 0b5df195dd3..b502515ae15 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/embedding/MiniMaxEmbeddingModelObservationIT.java @@ -37,7 +37,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; @@ -106,7 +106,7 @@ public MiniMaxApi minimaxApi() { public MiniMaxEmbeddingModel minimaxEmbeddingModel(MiniMaxApi minimaxApi, TestObservationRegistry observationRegistry) { return new MiniMaxEmbeddingModel(minimaxApi, MetadataMode.EMBED, MiniMaxEmbeddingOptions.builder().build(), - RetryTemplate.defaultInstance(), observationRegistry); + new RetryTemplate(), observationRegistry); } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index f7314603ec3..3d03ff182e2 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -69,8 +69,9 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -199,8 +200,19 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = null; + try { + completionEntity = this.retryTemplate + .execute(() -> this.mistralAiApi.chatCompletionEntity(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } ChatCompletion chatCompletion = completionEntity.getBody(); @@ -272,8 +284,18 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - Flux completionChunks = this.retryTemplate - .execute(ctx -> this.mistralAiApi.chatCompletionStream(request)); + Flux completionChunks = null; + try { + completionChunks = this.retryTemplate.execute(() -> this.mistralAiApi.chatCompletionStream(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 8650fca10b7..a782ce6a179 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -39,7 +39,8 @@ import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; /** @@ -140,8 +141,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - var apiEmbeddingResponse = this.retryTemplate - .execute(ctx -> this.mistralAiApi.embeddings(apiRequest).getBody()); + MistralAiApi.EmbeddingList apiEmbeddingResponse = null; + try { + apiEmbeddingResponse = this.retryTemplate + .execute(() -> this.mistralAiApi.embeddings(apiRequest).getBody()); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java index 0717520d766..281dbb0a809 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java @@ -34,8 +34,9 @@ import org.springframework.ai.moderation.ModerationResponse; import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest; @@ -81,27 +82,39 @@ public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, R @Override public ModerationResponse call(ModerationPrompt moderationPrompt) { - return this.retryTemplate.execute(ctx -> { + try { + return this.retryTemplate.execute(() -> { - var instructions = moderationPrompt.getInstructions().getText(); + var instructions = moderationPrompt.getInstructions().getText(); - var moderationRequest = new MistralAiModerationRequest(instructions); + var moderationRequest = new MistralAiModerationRequest(instructions); - if (this.defaultOptions != null) { - moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, - MistralAiModerationRequest.class); + if (this.defaultOptions != null) { + moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, + MistralAiModerationRequest.class); + } + else { + // moderationPrompt.getOptions() never null but model can be empty, + // cause + // by ModerationPrompt constructor + moderationRequest = ModelOptionsUtils.merge( + toMistralAiModerationOptions(moderationPrompt.getOptions()), moderationRequest, + MistralAiModerationRequest.class); + } + + var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); + + return convertResponse(moderationResponseEntity, moderationRequest); + }); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; } else { - // moderationPrompt.getOptions() never null but model can be empty, cause - // by ModerationPrompt constructor - moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()), - moderationRequest, MistralAiModerationRequest.class); + throw new RuntimeException(e.getCause()); } - - var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); - - return convertResponse(moderationResponseEntity, moderationRequest); - }); + } } private ModerationResponse convertResponse(ResponseEntity moderationResponseEntity, diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java index 5d8a30d309c..03ac09d993c 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java @@ -38,7 +38,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -190,7 +190,7 @@ public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi, return MistralAiChatModel.builder() .mistralAiApi(mistralAiApi) .defaultOptions(MistralAiChatOptions.builder().build()) - .retryTemplate(RetryTemplate.defaultInstance()) + .retryTemplate(new RetryTemplate()) .observationRegistry(observationRegistry) .build(); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java index 8341c0fa22e..f3287c9fc3e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelObservationIT.java @@ -34,7 +34,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; @@ -110,7 +110,7 @@ public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi mistralAiApi return MistralAiEmbeddingModel.builder() .mistralAiApi(mistralAiApi) .options(MistralAiEmbeddingOptions.builder().build()) - .retryTemplate(RetryTemplate.defaultInstance()) + .retryTemplate(new RetryTemplate()) .observationRegistry(observationRegistry) .build(); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 9709a3783ca..f3d38e813ed 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -40,11 +40,11 @@ import org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingRequest; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -75,7 +75,7 @@ public class MistralAiRetryTests { public void beforeEach() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = MistralAiChatModel.builder() .mistralAiApi(this.mistralAiApi) @@ -110,7 +110,7 @@ public void mistralAiChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -139,7 +139,7 @@ public void mistralAiChatStreamTransientError() { assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -167,7 +167,7 @@ public void mistralAiEmbeddingTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -189,7 +189,7 @@ public void mistralAiChatMixedTransientAndNonTransientErrors() { assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); // Should have 1 retry attempt before hitting non-transient error - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1); } private static class TestRetryListener implements RetryListener { @@ -199,14 +199,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 7cb87eb8f3b..969454ee8b3 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -70,7 +70,8 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.util.json.JsonParser; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -246,7 +247,18 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon this.observationRegistry) .observe(() -> { - OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request)); + OllamaApi.ChatResponse ollamaResponse = null; + try { + ollamaResponse = this.retryTemplate.execute(() -> this.chatApi.chat(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() : ollamaResponse.message() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 693f892e940..232269fa97f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -35,10 +35,10 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.core.io.ClassPathResource; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -93,19 +93,21 @@ public OllamaApi ollamaApi() { @Bean public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { - RetryTemplate retryTemplate = RetryTemplate.builder() + RetryPolicy retryPolicy = RetryPolicy.builder() .maxAttempts(1) - .retryOn(TransientAiException.class) - .fixedBackoff(Duration.ofSeconds(1)) - .withListener(new RetryListener() { - - @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - } - }) + .includes(TransientAiException.class) + .delay(Duration.ofSeconds(1)) .build(); + + RetryTemplate retryTemplate = new RetryTemplate(retryPolicy); + retryTemplate.setRetryListener(new RetryListener() { + + @Override + public void onRetryFailure(final RetryPolicy policy, final Retryable retryable, + final Throwable throwable) { + logger.warn("Retry error. Retry count:" + (throwable.getSuppressed().length + 1), throwable); + } + }); return OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(OllamaChatOptions.builder().model(MODEL).temperature(0.9).build()) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java index 323c969c6fa..a744aead9d1 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java @@ -35,10 +35,10 @@ import org.springframework.ai.retry.NonTransientAiException; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.web.client.ResourceAccessException; import static org.assertj.core.api.Assertions.assertThat; @@ -71,7 +71,7 @@ class OllamaRetryTests { public void beforeEach() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = OllamaChatModel.builder() .ollamaApi(this.ollamaApi) @@ -96,7 +96,7 @@ void ollamaChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -130,7 +130,7 @@ void ollamaChatNonTransientErrorShouldNotRetry() { .hasMessage("Model not found"); assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(0); verify(this.ollamaApi, times(1)).chat(isA(OllamaApi.ChatRequest.class)); } @@ -202,14 +202,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java index 759eac07e09..d83e4c8caae 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java @@ -31,8 +31,8 @@ import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -127,8 +127,13 @@ public SpeechResponse call(SpeechPrompt speechPrompt) { OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt); - ResponseEntity speechEntity = this.retryTemplate - .execute(ctx -> this.audioApi.createSpeech(speechRequest)); + ResponseEntity speechEntity; + try { + speechEntity = this.retryTemplate.execute(() -> this.audioApi.createSpeech(speechRequest)); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI audio speech API", e); + } var speech = speechEntity.getBody(); @@ -154,8 +159,13 @@ public Flux stream(SpeechPrompt speechPrompt) { OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt); - Flux> speechEntity = this.retryTemplate - .execute(ctx -> this.audioApi.stream(speechRequest)); + Flux> speechEntity; + try { + speechEntity = this.retryTemplate.execute(() -> this.audioApi.stream(speechRequest)); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI audio speech streaming API", e); + } return speechEntity.map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity)))); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java index 365b25cffb8..db5508b23c0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java @@ -30,8 +30,8 @@ import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.Resource; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -112,8 +112,14 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro if (request.responseFormat().isJsonType()) { - ResponseEntity transcriptionEntity = this.retryTemplate - .execute(ctx -> this.audioApi.createTranscription(request, StructuredResponse.class)); + ResponseEntity transcriptionEntity; + try { + transcriptionEntity = this.retryTemplate + .execute(() -> this.audioApi.createTranscription(request, StructuredResponse.class)); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI transcription API", e); + } var transcription = transcriptionEntity.getBody(); @@ -133,8 +139,14 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro } else { - ResponseEntity transcriptionEntity = this.retryTemplate - .execute(ctx -> this.audioApi.createTranscription(request, String.class)); + ResponseEntity transcriptionEntity; + try { + transcriptionEntity = this.retryTemplate + .execute(() -> this.audioApi.createTranscription(request, String.class)); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI transcription API", e); + } var transcription = transcriptionEntity.getBody(); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 246b7893c4a..3de07809929 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -22,7 +22,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -78,13 +77,13 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.Resource; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** @@ -196,8 +195,14 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); + ResponseEntity completionEntity; + try { + completionEntity = this.retryTemplate + .execute(() -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI chat completion API", e); + } var chatCompletion = completionEntity.getBody(); @@ -402,14 +407,15 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); } - private MultiValueMap getAdditionalHttpHeaders(Prompt prompt) { + private HttpHeaders getAdditionalHttpHeaders(Prompt prompt) { Map headers = new HashMap<>(this.defaultOptions.getHttpHeaders()); if (prompt.getOptions() != null && prompt.getOptions() instanceof OpenAiChatOptions chatOptions) { headers.putAll(chatOptions.getHttpHeaders()); } - return CollectionUtils.toMultiValueMap( - headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue())))); + HttpHeaders httpHeaders = new HttpHeaders(); + headers.forEach(httpHeaders::add); + return httpHeaders; } private Generation buildGeneration(Choice choice, Map metadata, ChatCompletionRequest request) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 47c06ac5a72..d5c3cb347d4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -42,7 +42,7 @@ import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; /** @@ -164,8 +164,14 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - EmbeddingList apiEmbeddingResponse = this.retryTemplate - .execute(ctx -> this.openAiApi.embeddings(apiRequest).getBody()); + EmbeddingList apiEmbeddingResponse; + try { + apiEmbeddingResponse = this.retryTemplate + .execute(() -> this.openAiApi.embeddings(apiRequest).getBody()); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI embedding API", e); + } if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index 68354662548..d850c6ceef2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -38,8 +38,8 @@ import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -141,8 +141,14 @@ public ImageResponse call(ImagePrompt imagePrompt) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - ResponseEntity imageResponseEntity = this.retryTemplate - .execute(ctx -> this.openAiImageApi.createImage(imageRequest)); + ResponseEntity imageResponseEntity; + try { + imageResponseEntity = this.retryTemplate + .execute(() -> this.openAiImageApi.createImage(imageRequest)); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI image API", e); + } ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java index 8e00c24b430..5a22be70ad0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java @@ -34,8 +34,8 @@ import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.openai.api.OpenAiModerationApi; import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -77,28 +77,34 @@ public OpenAiModerationModel withDefaultOptions(OpenAiModerationOptions defaultO @Override public ModerationResponse call(ModerationPrompt moderationPrompt) { - return this.retryTemplate.execute(ctx -> { + try { + return this.retryTemplate.execute(() -> { - String instructions = moderationPrompt.getInstructions().getText(); + String instructions = moderationPrompt.getInstructions().getText(); - OpenAiModerationApi.OpenAiModerationRequest moderationRequest = new OpenAiModerationApi.OpenAiModerationRequest( - instructions); + OpenAiModerationApi.OpenAiModerationRequest moderationRequest = new OpenAiModerationApi.OpenAiModerationRequest( + instructions); - if (this.defaultOptions != null) { - moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, - OpenAiModerationApi.OpenAiModerationRequest.class); - } + if (this.defaultOptions != null) { + moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, + OpenAiModerationApi.OpenAiModerationRequest.class); + } - if (moderationPrompt.getOptions() != null) { - moderationRequest = ModelOptionsUtils.merge(toOpenAiModerationOptions(moderationPrompt.getOptions()), - moderationRequest, OpenAiModerationApi.OpenAiModerationRequest.class); - } + if (moderationPrompt.getOptions() != null) { + moderationRequest = ModelOptionsUtils.merge( + toOpenAiModerationOptions(moderationPrompt.getOptions()), moderationRequest, + OpenAiModerationApi.OpenAiModerationRequest.class); + } - ResponseEntity moderationResponseEntity = this.openAiModerationApi - .createModeration(moderationRequest); + ResponseEntity moderationResponseEntity = this.openAiModerationApi + .createModeration(moderationRequest); - return convertResponse(moderationResponseEntity, moderationRequest); - }); + return convertResponse(moderationResponseEntity, moderationRequest); + }); + } + catch (Exception e) { + throw new RuntimeException("Error calling OpenAI moderation API", e); + } } private ModerationResponse convertResponse( diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index f7cf3c8abcf..49acf9b7242 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -48,8 +48,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -102,7 +100,7 @@ public static Builder builder() { private final ApiKey apiKey; - private final MultiValueMap headers; + private final HttpHeaders headers; private final String completionsPath; @@ -127,8 +125,8 @@ public static Builder builder() { * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. */ - public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, - String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + public OpenAiApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String completionsPath, String embeddingsPath, + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { this.baseUrl = baseUrl; this.apiKey = apiKey; @@ -145,7 +143,7 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he Consumer finalHeaders = h -> { h.setContentType(MediaType.APPLICATION_JSON); h.set(HTTP_USER_AGENT_HEADER, SPRING_AI_USER_AGENT); - h.addAll(HttpHeaders.readOnlyHttpHeaders(headers)); + h.addAll(headers); }; this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) @@ -170,9 +168,8 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he * @param webClient WebClient instance. * @param responseErrorHandler Response error handler. */ - public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, - String embeddingsPath, ResponseErrorHandler responseErrorHandler, RestClient restClient, - WebClient webClient) { + public OpenAiApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String completionsPath, String embeddingsPath, + ResponseErrorHandler responseErrorHandler, RestClient restClient, WebClient webClient) { this.baseUrl = baseUrl; this.apiKey = apiKey; this.headers = headers; @@ -206,7 +203,7 @@ public static String getTextContent(List con * and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionEntity(chatRequest, new HttpHeaders()); } /** @@ -218,7 +215,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest * and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, REQUEST_BODY_NULL_MESSAGE); Assert.isTrue(!chatRequest.stream(), STREAM_FALSE_MESSAGE); @@ -228,7 +225,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -244,7 +241,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionStream(chatRequest, new HttpHeaders()); } /** @@ -256,7 +253,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, REQUEST_BODY_NULL_MESSAGE); Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); @@ -267,7 +264,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -352,7 +349,7 @@ public ResponseEntity> embeddings(EmbeddingRequest< } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (null == headers.getFirst(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + if (headers.get(HttpHeaders.AUTHORIZATION) == null && !(this.apiKey instanceof NoopApiKey)) { headers.setBearerAuth(this.apiKey.getValue()); } } @@ -366,7 +363,7 @@ ApiKey getApiKey() { return this.apiKey; } - MultiValueMap getHeaders() { + HttpHeaders getHeaders() { return this.headers; } @@ -2024,7 +2021,8 @@ public Builder() { public Builder(OpenAiApi api) { this.baseUrl = api.getBaseUrl(); this.apiKey = api.getApiKey(); - this.headers = new LinkedMultiValueMap<>(api.getHeaders()); + this.headers = new HttpHeaders(); + this.headers.addAll(api.getHeaders()); this.completionsPath = api.getCompletionsPath(); this.embeddingsPath = api.getEmbeddingsPath(); this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder(); @@ -2036,7 +2034,7 @@ public Builder(OpenAiApi api) { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private String completionsPath = "/v1/chat/completions"; @@ -2065,7 +2063,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 786eaea18ee..2de106ab54b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -69,9 +69,8 @@ public class OpenAiAudioApi { * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. */ - public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { + public OpenAiAudioApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { Consumer authHeaders = h -> h.addAll(HttpHeaders.readOnlyHttpHeaders(headers)); @@ -852,7 +851,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -878,7 +877,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java index e7fd7d06d27..98eb5ae15e7 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java @@ -49,9 +49,9 @@ public class OpenAiFileApi { private final RestClient restClient; - public OpenAiFileApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - Consumer authHeaders = h -> h.addAll(HttpHeaders.readOnlyHttpHeaders(headers)); + public OpenAiFileApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + Consumer authHeaders = h -> h.addAll(headers); this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) @@ -364,7 +364,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -388,7 +388,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index 7aec86b7267..378653f6e2f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -31,8 +31,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -60,7 +58,7 @@ public class OpenAiImageApi { * @param restClientBuilder the rest client builder to use. * @param responseErrorHandler the response error handler to use. */ - public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String imagesPath, + public OpenAiImageApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String imagesPath, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { // @formatter:off @@ -178,7 +176,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -210,7 +208,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index d13d6a95a4c..8c97a90f7ff 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -29,8 +29,6 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -55,8 +53,8 @@ public class OpenAiModerationApi { * @param apiKey OpenAI apiKey. * @param restClientBuilder the rest client builder to use. */ - public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + public OpenAiModerationApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { // @formatter:off this.restClient = restClientBuilder.clone() @@ -178,7 +176,7 @@ public static final class Builder { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private RestClient.Builder restClientBuilder = RestClient.builder(); @@ -202,7 +200,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java index 47708ccb3b7..b7ad61960a3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiResponseHeaderExtractor.java @@ -71,22 +71,18 @@ public static RateLimit extractAiResponseHeaders(ResponseEntity response) { private static Duration getHeaderAsDuration(ResponseEntity response, String headerName) { var headers = response.getHeaders(); - if (null != headers.getFirst(headerName)) { - var values = headers.get(headerName); - if (!CollectionUtils.isEmpty(values)) { - return DurationFormatter.TIME_UNIT.parse(values.get(0)); - } + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return DurationFormatter.TIME_UNIT.parse(values.get(0)); } return null; } private static Long getHeaderAsLong(ResponseEntity response, String headerName) { var headers = response.getHeaders(); - if (null != headers.getFirst(headerName)) { - var values = headers.get(headerName); - if (!CollectionUtils.isEmpty(values)) { - return parseLong(headerName, values.get(0)); - } + var values = headers.get(headerName); + if (!CollectionUtils.isEmpty(values)) { + return parseLong(headerName, values.get(0)); } return null; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index ce54f0031f4..e3cb1514932 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -38,8 +38,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -67,7 +65,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); WebClient.Builder webClientBuilder = WebClient.builder(); @@ -193,7 +191,7 @@ void testNullApiKeyValue() { @Test void testBuilderMethodChaining() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Test-Header", "test-value"); OpenAiApi api = OpenAiApi.builder() @@ -212,7 +210,7 @@ void testBuilderMethodChaining() { @Test void testCustomHeadersPreservation() { - MultiValueMap customHeaders = new LinkedMultiValueMap<>(); + HttpHeaders customHeaders = new HttpHeaders(); customHeaders.add("X-Custom-Header", "custom-value"); customHeaders.add("X-Organization", "org-123"); customHeaders.add("User-Agent", "Custom-Client/1.0"); @@ -224,7 +222,7 @@ void testCustomHeadersPreservation() { @Test void testComplexMultiValueHeaders() { - MultiValueMap multiHeaders = new LinkedMultiValueMap<>(); + HttpHeaders multiHeaders = new HttpHeaders(); multiHeaders.add("Accept", "application/json"); multiHeaders.add("Accept", "text/plain"); multiHeaders.add("Cache-Control", "no-cache"); @@ -278,7 +276,7 @@ void testDifferentApiKeyTypes() { @Test void testBuilderCreatesIndependentInstances() { - MultiValueMap sharedHeaders = new LinkedMultiValueMap<>(); + HttpHeaders sharedHeaders = new HttpHeaders(); sharedHeaders.add("X-Shared", "value"); OpenAiApi.Builder builder = OpenAiApi.builder() @@ -294,8 +292,8 @@ void testBuilderCreatesIndependentInstances() { OpenAiApi api2 = builder.build(); // Both APIs should have the modified headers since they share the same reference - assertThat(api1.getHeaders()).containsKey("X-Modified"); - assertThat(api2.getHeaders()).containsKey("X-Modified"); + assertThat(api1.getHeaders().containsHeader("X-Modified")).isTrue(); + assertThat(api2.getHeaders().containsHeader("X-Modified")).isTrue(); } @Test @@ -312,23 +310,23 @@ void testMutatePreservesResponseErrorHandler() { @Test void testMutateCreatesIndependentHeaders() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("X-Original", "value1"); OpenAiApi original = OpenAiApi.builder().apiKey(TEST_API_KEY).headers(headers).build(); - MultiValueMap newHeaders = new LinkedMultiValueMap<>(); + HttpHeaders newHeaders = new HttpHeaders(); newHeaders.add("X-New", "value2"); OpenAiApi mutated = original.mutate().headers(newHeaders).build(); // Original headers should be unchanged - assertThat(original.getHeaders()).containsKey("X-Original"); - assertThat(original.getHeaders()).doesNotContainKey("X-New"); + assertThat(original.getHeaders().containsHeader("X-Original")).isTrue(); + assertThat(original.getHeaders().containsHeader("X-New")).isFalse(); // Mutated should have new headers - assertThat(mutated.getHeaders()).doesNotContainKey("X-Original"); - assertThat(mutated.getHeaders()).containsKey("X-New"); + assertThat(mutated.getHeaders().containsHeader("X-Original")).isFalse(); + assertThat(mutated.getHeaders().containsHeader("X-New")).isTrue(); } @Test @@ -446,7 +444,7 @@ void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws Interrupt OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + HttpHeaders additionalHeaders = new HttpHeaders(); additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); @@ -543,7 +541,7 @@ void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws Interrupte OpenAiApi.ChatCompletionMessage.Role.USER); OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + HttpHeaders additionalHeaders = new HttpHeaders(); additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); List response = api.chatCompletionStream(request, additionalHeaders) .collectList() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java index 36ca255710e..cd549a9bebb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java @@ -20,7 +20,7 @@ import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.util.LinkedMultiValueMap; +import org.springframework.http.HttpHeaders; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -88,12 +88,12 @@ void mutateDoesNotAffectOriginal() { @Test void mutateHeadersCreatesDistinctHeaders() { - OpenAiApi mutatedApi = this.baseApi.mutate() - .headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value")))) - .build(); + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Test", "value"); + OpenAiApi mutatedApi = this.baseApi.mutate().headers(headers).build(); - assertThat(mutatedApi.getHeaders()).containsKey("X-Test"); - assertThat(this.baseApi.getHeaders()).doesNotContainKey("X-Test"); + assertThat(mutatedApi.getHeaders().get("X-Test")).isNotNull(); + assertThat(this.baseApi.getHeaders().get("X-Test")).isNull(); } @Test @@ -129,7 +129,7 @@ void mutateAndCloneAreEquivalent() { @Test void testApiMutateWithComplexHeaders() { - LinkedMultiValueMap complexHeaders = new LinkedMultiValueMap<>(); + HttpHeaders complexHeaders = new HttpHeaders(); complexHeaders.add("Authorization", "Bearer custom-token"); complexHeaders.add("X-Custom-Header", "value1"); complexHeaders.add("X-Custom-Header", "value2"); @@ -137,9 +137,9 @@ void testApiMutateWithComplexHeaders() { OpenAiApi mutatedApi = this.baseApi.mutate().headers(complexHeaders).build(); - assertThat(mutatedApi.getHeaders()).containsKey("Authorization"); - assertThat(mutatedApi.getHeaders()).containsKey("X-Custom-Header"); - assertThat(mutatedApi.getHeaders()).containsKey("User-Agent"); + assertThat(mutatedApi.getHeaders().get("Authorization")).isNotNull(); + assertThat(mutatedApi.getHeaders().get("X-Custom-Header")).isNotNull(); + assertThat(mutatedApi.getHeaders().get("User-Agent")).isNotNull(); assertThat(mutatedApi.getHeaders().get("X-Custom-Header")).hasSize(2); } @@ -155,11 +155,11 @@ void testMutateWithEmptyOptions() { @Test void testApiMutateWithEmptyHeaders() { - LinkedMultiValueMap emptyHeaders = new LinkedMultiValueMap<>(); + HttpHeaders emptyHeaders = new HttpHeaders(); OpenAiApi mutatedApi = this.baseApi.mutate().headers(emptyHeaders).build(); - assertThat(mutatedApi.getHeaders()).isEmpty(); + assertThat(mutatedApi.getHeaders().isEmpty()).isTrue(); } @Test diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java index 143fd9eaa68..0bcbf856db9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java @@ -36,8 +36,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -65,7 +63,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java index ecd506277d3..2d567807abc 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java @@ -37,8 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -65,7 +63,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); WebClient.Builder webClientBuilder = WebClient.builder(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiAudioTranscriptionModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiAudioTranscriptionModelTests.java index fba149ebe39..411a339fd69 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiAudioTranscriptionModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiAudioTranscriptionModelTests.java @@ -32,10 +32,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.test.web.client.MockRestServiceServer; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -123,9 +123,8 @@ static class Config { @Bean public OpenAiAudioApi openAiAudioApi(RestClient.Builder builder) { - return new OpenAiAudioApi("https://api.openai.com", new SimpleApiKey("test-api-key"), - new LinkedMultiValueMap<>(), builder, WebClient.builder(), - RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + return new OpenAiAudioApi("https://api.openai.com", new SimpleApiKey("test-api-key"), new HttpHeaders(), + builder, WebClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } @Bean diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index 0c949a15f71..3a176ffaef4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -41,10 +41,10 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.core.io.ByteArrayResource; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; -import org.springframework.util.MultiValueMap; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; @@ -66,7 +66,7 @@ public class MessageTypeContentTests { ArgumentCaptor pomptCaptor; @Captor - ArgumentCaptor> headersCaptor; + ArgumentCaptor headersCaptor; Flux fluxResponse = Flux.generate( () -> new ChatCompletionChunk("id", List.of(), 0L, "model", null, "fp", "object", null), (state, sink) -> { @@ -89,7 +89,7 @@ public void systemMessageSimpleContentType() { this.chatModel.call(new Prompt(List.of(new SystemMessage("test message")))); validateStringContent(this.pomptCaptor.getValue()); - assertThat(this.headersCaptor.getValue()).isEmpty(); + assertThat(this.headersCaptor.getValue().isEmpty()).isTrue(); } @Test @@ -112,7 +112,7 @@ public void streamUserMessageSimpleContentType() { this.chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe(); validateStringContent(this.pomptCaptor.getValue()); - assertThat(this.headersCaptor.getValue()).isEmpty(); + assertThat(this.headersCaptor.getValue().isEmpty()).isTrue(); } private void validateStringContent(ChatCompletionRequest chatCompletionRequest) { @@ -207,7 +207,7 @@ public void userMessageWithEmptyMediaList() { .build()))); validateStringContent(this.pomptCaptor.getValue()); - assertThat(this.headersCaptor.getValue()).isEmpty(); + assertThat(this.headersCaptor.getValue().isEmpty()).isTrue(); } @Test @@ -308,7 +308,7 @@ public void streamWithMultipleMessagesAndMedia() { // User message should be complex assertThat(request.messages().get(1).rawContent()).isInstanceOf(List.class); - assertThat(this.headersCaptor.getValue()).isEmpty(); + assertThat(this.headersCaptor.getValue().isEmpty()).isTrue(); } // Helper method for testing different image formats diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java index 3b73adf7f0b..fa2daa3962f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java @@ -40,7 +40,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -175,7 +176,8 @@ public OpenAiApi openAiApi() { @Bean public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi, TestObservationRegistry observationRegistry) { return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(), - ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry); + ToolCallingManager.builder().build(), new RetryTemplate(RetryPolicy.withDefaults()), + observationRegistry); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index e19e82640b2..293437a0dc4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -65,11 +65,11 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.core.io.ClassPathResource; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -109,7 +109,7 @@ public class OpenAiRetryTests { public void beforeEach() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = OpenAiChatModel.builder() .openAiApi(this.openAiApi) @@ -145,8 +145,8 @@ public void openAiChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); } @Test @@ -174,8 +174,8 @@ public void openAiChatStreamTransientError() { assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); } @Test @@ -202,8 +202,8 @@ public void openAiEmbeddingTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); } @Test @@ -229,8 +229,8 @@ public void openAiAudioTranscriptionTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(expectedResponse.text()); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); } @Test @@ -256,8 +256,8 @@ public void openAiImageTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); - assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); + assertThat(this.retryListener.retryCount).isEqualTo(2); } @Test @@ -270,19 +270,19 @@ public void openAiImageNonTransientError() { private static class TestRetryListener implements RetryListener { - int onErrorRetryCount = 0; + int retryCount = 0; int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void beforeRetry(RetryPolicy retryPolicy, Retryable retryable) { + this.retryCount++; } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java index 3dc59444e82..d1f8dee4929 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java @@ -40,7 +40,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.RetryUtils; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java index aa76a67f7a5..5d3cd477653 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java @@ -37,7 +37,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; @@ -111,7 +112,7 @@ public OpenAiApi openAiApi() { public OpenAiEmbeddingModel openAiEmbeddingModel(OpenAiApi openAiApi, TestObservationRegistry observationRegistry) { return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder().build(), - RetryTemplate.defaultInstance(), observationRegistry); + new RetryTemplate(RetryPolicy.withDefaults()), observationRegistry); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelNoOpApiKeysIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelNoOpApiKeysIT.java index 61160a08a1f..bbd3373f826 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelNoOpApiKeysIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelNoOpApiKeysIT.java @@ -31,7 +31,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -70,7 +71,7 @@ public OpenAiImageApi openAiImageApi() { @Bean public OpenAiImageModel openAiImageModel(OpenAiImageApi openAiImageApi) { return new OpenAiImageModel(openAiImageApi, OpenAiImageOptions.builder().build(), - RetryTemplate.defaultInstance(), TestObservationRegistry.create()); + new RetryTemplate(RetryPolicy.withDefaults()), TestObservationRegistry.create()); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java index 37dc7abcdba..8a6fb12b13b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageModelObservationIT.java @@ -34,7 +34,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.image.observation.ImageModelObservationDocumentation.HighCardinalityKeyNames; @@ -105,7 +106,7 @@ public OpenAiImageApi openAiImageApi() { public OpenAiImageModel openAiImageModel(OpenAiImageApi openAiImageApi, TestObservationRegistry observationRegistry) { return new OpenAiImageModel(openAiImageApi, OpenAiImageOptions.builder().build(), - RetryTemplate.defaultInstance(), observationRegistry); + new RetryTemplate(RetryPolicy.withDefaults()), observationRegistry); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java index 50bfd71fef3..6121893973a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java @@ -37,8 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -64,7 +62,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java index 262eb21e05b..e7712e97d97 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java @@ -37,8 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -64,7 +62,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + HttpHeaders headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 4bef9d1145b..380ed5035ae 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -50,7 +50,8 @@ import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -139,7 +140,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { (VertexAiTextEmbeddingOptions) options); PredictResponse embeddingResponse = this.retryTemplate - .execute(context -> getPredictResponse(client, predictRequestBuilder)); + .execute(() -> getPredictResponse(client, predictRequestBuilder)); int index = 0; int totalTokenCount = 0; @@ -163,6 +164,14 @@ public EmbeddingResponse call(EmbeddingRequest request) { return response; } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } }); } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java index e8627f3d625..36853d0bca7 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java @@ -23,7 +23,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel { diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java index 6d88ef6e958..08f4295dbcd 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -36,10 +36,10 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -76,7 +76,7 @@ public class VertexAiTextEmbeddingRetryTests { public void setUp() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.embeddingModel = new TestVertexAiTextEmbeddingModel(this.mockConnectionDetails, VertexAiTextEmbeddingOptions.builder().build(), this.retryTemplate); @@ -123,7 +123,7 @@ public void vertexAiEmbeddingTransientError() { assertThat(result).isNotNull(); assertThat(result.getResults()).hasSize(1); assertThat(result.getResults().get(0).getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); verify(this.mockPredictRequestBuilder, times(3)).build(); @@ -163,14 +163,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 3a55ee58611..c7b95ddad88 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -91,8 +91,9 @@ import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter; import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager; import org.springframework.beans.factory.DisposableBean; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.lang.NonNull; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -389,28 +390,45 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> this.retryTemplate.execute(context -> { - - var geminiRequest = createGeminiRequest(prompt); - - GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + .observe(() -> { + try { + return this.retryTemplate.execute(() -> { + + var geminiRequest = createGeminiRequest(prompt); + + GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + + List generations = generateContentResponse.getCandidatesList() + .stream() + .map(this::responseCandidateToGeneration) + .flatMap(List::stream) + .toList(); + + GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata(); + Usage currentUsage = (usage != null) + ? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount()) + : new EmptyUsage(); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + toChatResponseMetadata(cumulativeUsage)); + + observationContext.setResponse(chatResponse); + return chatResponse; + }); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } - List generations = generateContentResponse.getCandidatesList() - .stream() - .map(this::responseCandidateToGeneration) - .flatMap(List::stream) - .toList(); - - GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata(); - Usage currentUsage = (usage != null) - ? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount()) - : new EmptyUsage(); - Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage)); - - observationContext.setResponse(chatResponse); - return chatResponse; - })); + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } + }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java index 33af68c57d2..e19e39fccb0 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -23,7 +23,7 @@ import com.google.cloud.vertexai.generativeai.GenerativeModel; import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; /** * @author Mark Pollack diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index 79ac33982c3..2b1b88d9a76 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -35,10 +35,10 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.retry.TransientAiException; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -68,7 +68,7 @@ public class VertexAiGeminiRetryTests { public void setUp() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = new TestVertexAiGeminiChatModel(this.vertexAI, VertexAiGeminiChatOptions.builder() @@ -101,7 +101,7 @@ public void vertexAiGeminiChatTransientError() throws IOException { // Assertions assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isEqualTo("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -168,7 +168,6 @@ public void vertexAiGeminiChatMaxRetriesExceeded() throws Exception { // Should throw the last TransientAiException after exhausting retries assertThrows(TransientAiException.class, () -> this.chatModel.call(new Prompt("test prompt"))); - // Verify retry attempts were made assertThat(this.retryListener.onErrorRetryCount).isGreaterThan(0); } @@ -249,7 +248,7 @@ public void vertexAiGeminiChatAlternatingErrorsAndSuccess() throws Exception { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isEqualTo("Success after alternating errors"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -260,14 +259,15 @@ private static class TestRetryListener implements RetryListener { int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 2c9ff3e54ff..4c6e0225bf9 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -70,8 +70,9 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -260,8 +261,18 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.zhiPuAiApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = null; + try { + completionEntity = this.retryTemplate.execute(() -> this.zhiPuAiApi.chatCompletionEntity(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } var chatCompletion = completionEntity.getBody(); @@ -319,8 +330,18 @@ public Flux stream(Prompt prompt) { Prompt requestPrompt = buildRequestPrompt(prompt); ChatCompletionRequest request = createRequest(requestPrompt, true); - Flux completionChunks = this.retryTemplate - .execute(ctx -> this.zhiPuAiApi.chatCompletionStream(request)); + Flux completionChunks = null; + try { + completionChunks = this.retryTemplate.execute(() -> this.zhiPuAiApi.chatCompletionStream(request)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index c1ec94262e1..7a60ddb883c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -41,7 +41,9 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -165,8 +167,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - var embeddingResponse = this.retryTemplate - .execute(ctx -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest)); + ResponseEntity> embeddingResponse = null; + try { + embeddingResponse = this.retryTemplate + .execute(() -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest)); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } if (embeddingResponse == null || embeddingResponse.getBody() == null || CollectionUtils.isEmpty(embeddingResponse.getBody().data())) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java index 406221e7d8a..e88231a9929 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java @@ -30,8 +30,9 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -71,30 +72,40 @@ public ZhiPuAiImageOptions getDefaultOptions() { @Override public ImageResponse call(ImagePrompt imagePrompt) { - return this.retryTemplate.execute(ctx -> { + try { + return this.retryTemplate.execute(() -> { - String instructions = imagePrompt.getInstructions().get(0).getText(); + String instructions = imagePrompt.getInstructions().get(0).getText(); - ZhiPuAiImageApi.ZhiPuAiImageRequest imageRequest = new ZhiPuAiImageApi.ZhiPuAiImageRequest(instructions, - ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL); + ZhiPuAiImageApi.ZhiPuAiImageRequest imageRequest = new ZhiPuAiImageApi.ZhiPuAiImageRequest(instructions, + ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL); - if (this.defaultOptions != null) { - imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest, - ZhiPuAiImageApi.ZhiPuAiImageRequest.class); - } + if (this.defaultOptions != null) { + imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest, + ZhiPuAiImageApi.ZhiPuAiImageRequest.class); + } - if (imagePrompt.getOptions() != null) { - imageRequest = ModelOptionsUtils.merge(toZhiPuAiImageOptions(imagePrompt.getOptions()), imageRequest, - ZhiPuAiImageApi.ZhiPuAiImageRequest.class); - } + if (imagePrompt.getOptions() != null) { + imageRequest = ModelOptionsUtils.merge(toZhiPuAiImageOptions(imagePrompt.getOptions()), + imageRequest, ZhiPuAiImageApi.ZhiPuAiImageRequest.class); + } - // Make the request - ResponseEntity imageResponseEntity = this.zhiPuAiImageApi - .createImage(imageRequest); + // Make the request + ResponseEntity imageResponseEntity = this.zhiPuAiImageApi + .createImage(imageRequest); - // Convert to org.springframework.ai.model derived ImageResponse data type - return convertResponse(imageResponseEntity, imageRequest); - }); + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(imageResponseEntity, imageRequest); + }); + } + catch (RetryException e) { + if (e.getCause() instanceof RuntimeException r) { + throw r; + } + else { + throw new RuntimeException(e.getCause()); + } + } } private ImageResponse convertResponse(ResponseEntity imageResponseEntity, diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 6c2ed11757b..ee1cabffe85 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -42,8 +42,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -85,7 +83,7 @@ public static Builder builder() { private final ApiKey apiKey; - private final MultiValueMap headers; + private final HttpHeaders headers; private final String completionsPath; @@ -143,7 +141,7 @@ public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restCl @Deprecated public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - this(baseUrl, new SimpleApiKey(zhiPuAiToken), new LinkedMultiValueMap<>(), DEFAULT_COMPLETIONS_PATH, + this(baseUrl, new SimpleApiKey(zhiPuAiToken), new HttpHeaders(), DEFAULT_COMPLETIONS_PATH, DEFAULT_EMBEDDINGS_PATH, restClientBuilder, WebClient.builder(), responseErrorHandler); } @@ -158,7 +156,7 @@ public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restCl * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. */ - private ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, + private ZhiPuAiApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { Assert.hasText(completionsPath, "Completions Path must not be null"); @@ -174,7 +172,7 @@ private ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap Consumer authHeaders = h -> { h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(HttpHeaders.readOnlyHttpHeaders(headers)); + h.addAll(headers); }; this.restClient = restClientBuilder.clone() @@ -201,9 +199,8 @@ private ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap * @param webClient WebClient instance. * @param responseErrorHandler Response error handler. */ - public ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, - String embeddingsPath, ResponseErrorHandler responseErrorHandler, RestClient restClient, - WebClient webClient) { + public ZhiPuAiApi(String baseUrl, ApiKey apiKey, HttpHeaders headers, String completionsPath, String embeddingsPath, + ResponseErrorHandler responseErrorHandler, RestClient restClient, WebClient webClient) { Assert.hasText(completionsPath, "Completions Path must not be null"); Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); Assert.notNull(headers, "Headers must not be null"); @@ -232,7 +229,7 @@ public static String getTextContent(List con * and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { - return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionEntity(chatRequest, new HttpHeaders()); } /** @@ -242,7 +239,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest * and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); @@ -251,7 +248,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest return this.restClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) .body(chatRequest) @@ -267,7 +264,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { - return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + return chatCompletionStream(chatRequest, new HttpHeaders()); } /** @@ -277,7 +274,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest, - MultiValueMap additionalHttpHeader) { + HttpHeaders additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); @@ -288,7 +285,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return this.webClient.post() .uri(this.completionsPath) .headers(headers -> { - headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)); + headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) @@ -358,7 +355,7 @@ public ResponseEntity> embeddings(EmbeddingRequest< } private void addDefaultHeadersIfMissing(HttpHeaders headers) { - if (null == headers.getFirst(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + if (headers.get(HttpHeaders.AUTHORIZATION) == null && !(this.apiKey instanceof NoopApiKey)) { headers.setBearerAuth(this.apiKey.getValue()); } } @@ -372,7 +369,7 @@ ApiKey getApiKey() { return this.apiKey; } - MultiValueMap getHeaders() { + HttpHeaders getHeaders() { return this.headers; } @@ -1244,7 +1241,8 @@ private Builder() { public Builder(ZhiPuAiApi api) { this.baseUrl = api.getBaseUrl(); this.apiKey = api.getApiKey(); - this.headers = new LinkedMultiValueMap<>(api.getHeaders()); + this.headers = new HttpHeaders(); + this.headers.addAll(api.getHeaders()); this.completionsPath = api.getCompletionsPath(); this.embeddingsPath = api.getEmbeddingsPath(); this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder(); @@ -1256,7 +1254,7 @@ public Builder(ZhiPuAiApi api) { private ApiKey apiKey; - private MultiValueMap headers = new LinkedMultiValueMap<>(); + private HttpHeaders headers = new HttpHeaders(); private String completionsPath = DEFAULT_COMPLETIONS_PATH; @@ -1285,7 +1283,7 @@ public Builder apiKey(String simpleApiKey) { return this; } - public Builder headers(MultiValueMap headers) { + public Builder headers(HttpHeaders headers) { Assert.notNull(headers, "headers cannot be null"); this.headers = headers; return this; diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java index a6409e70c20..5b1f0d9795a 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java @@ -37,8 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -66,7 +64,7 @@ void testMinimalBuilder() { @Test void testFullBuilder() { - MultiValueMap headers = new LinkedMultiValueMap<>(); + var headers = new HttpHeaders(); headers.add("Custom-Header", "test-value"); RestClient.Builder restClientBuilder = RestClient.builder(); WebClient.Builder webClientBuilder = WebClient.builder(); @@ -232,7 +230,7 @@ void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws Interrupt ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( List.of(chatCompletionMessage), "glm-4-flash", 0.8, false); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + var additionalHeaders = new HttpHeaders(); additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); @@ -289,7 +287,7 @@ void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws Interrupte ZhiPuAiApi.ChatCompletionMessage.Role.USER); ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( List.of(chatCompletionMessage), "glm-4-flash", 0.8, true); - MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + var additionalHeaders = new HttpHeaders(); additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); List response = api.chatCompletionStream(request, additionalHeaders) .collectList() diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index 327a7f45329..7746360829c 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -52,11 +52,11 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.Data; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.ZhiPuAiImageRequest; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.ZhiPuAiImageResponse; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -88,7 +88,7 @@ public class ZhiPuAiRetryTests { public void beforeEach() { this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; this.retryListener = new TestRetryListener(); - this.retryTemplate.registerListener(this.retryListener); + this.retryTemplate.setRetryListener(this.retryListener); this.chatModel = new ZhiPuAiChatModel(this.zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), this.retryTemplate); @@ -115,7 +115,7 @@ public void zhiPuAiChatTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -144,7 +144,7 @@ public void zhiPuAiChatStreamTransientError() { assertThat(result).isNotNull(); assertThat(result.collectList().block().get(0).getResult().getOutput().getText()).isSameAs("Response"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -174,7 +174,7 @@ public void zhiPuAiEmbeddingTransientError() { assertThat(result).isNotNull(); // choose the first result assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -201,7 +201,7 @@ public void zhiPuAiImageTransientError() { assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } @@ -213,21 +213,22 @@ public void zhiPuAiImageNonTransientError() { () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); } - private class TestRetryListener implements RetryListener { + private static class TestRetryListener implements RetryListener { int onErrorRetryCount = 0; int onSuccessRetryCount = 0; @Override - public void onSuccess(RetryContext context, RetryCallback callback, T result) { - this.onSuccessRetryCount = context.getRetryCount(); + public void beforeRetry(final RetryPolicy retryPolicy, final Retryable retryable) { + // Count each retry attempt + this.onErrorRetryCount++; } @Override - public void onError(RetryContext context, RetryCallback callback, - Throwable throwable) { - this.onErrorRetryCount = context.getRetryCount(); + public void onRetrySuccess(final RetryPolicy retryPolicy, final Retryable retryable, final Object result) { + // Count successful retries - we increment when we succeed after a failure + this.onSuccessRetryCount++; } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java index 953c7c3bb4e..a01fc35ec0d 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java @@ -39,7 +39,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames; @@ -164,8 +164,8 @@ public ZhiPuAiApi zhiPuAiApi() { @Bean public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, TestObservationRegistry observationRegistry) { - return new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), - RetryTemplate.defaultInstance(), observationRegistry); + return new ZhiPuAiChatModel(zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), new RetryTemplate(), + observationRegistry); } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java index f6a33037566..11f3bc70f75 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java @@ -37,7 +37,7 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; +import org.springframework.core.retry.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; @@ -107,7 +107,7 @@ public ZhiPuAiEmbeddingModel zhiPuAiEmbeddingModel(ZhiPuAiApi zhiPuAiApi, TestObservationRegistry observationRegistry) { return new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, ZhiPuAiEmbeddingOptions.builder().model(ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL).build(), - RetryTemplate.defaultInstance(), observationRegistry); + new RetryTemplate(), observationRegistry); } } diff --git a/pom.xml b/pom.xml index ca95c983750..aff74d8e334 100644 --- a/pom.xml +++ b/pom.xml @@ -329,6 +329,7 @@ 4.12.0 + 5.5.6 4.1.0 @@ -363,7 +364,6 @@ false - 5.5.6 diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml index cd9f1502d78..1706fe61276 100644 --- a/spring-ai-retry/pom.xml +++ b/spring-ai-retry/pom.xml @@ -42,13 +42,6 @@ - - org.springframework.retry - spring-retry - 2.0.12 - - - org.springframework spring-web diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java index 44c405ca6d8..c82762f60e0 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java @@ -26,11 +26,20 @@ */ public class NonTransientAiException extends RuntimeException { - public NonTransientAiException(String message) { + /** + * Constructor with message. + * @param message the exception message + */ + public NonTransientAiException(final String message) { super(message); } - public NonTransientAiException(String message, Throwable cause) { + /** + * Constructor with message and cause. + * @param message the exception message + * @param cause the exception cause + */ + public NonTransientAiException(final String message, final Throwable cause) { super(message, cause); } diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index 2d7cbf66742..2a1f3c97c73 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -20,17 +20,18 @@ import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.core.retry.RetryListener; +import org.springframework.core.retry.RetryPolicy; +import org.springframework.core.retry.RetryTemplate; +import org.springframework.core.retry.Retryable; import org.springframework.http.HttpMethod; import org.springframework.http.client.ClientHttpResponse; import org.springframework.lang.NonNull; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; -import org.springframework.retry.support.RetryTemplate; import org.springframework.util.StreamUtils; import org.springframework.web.client.ResourceAccessException; import org.springframework.web.client.ResponseErrorHandler; @@ -45,28 +46,44 @@ */ public abstract class RetryUtils { + private static final int DEFAULT_MAX_ATTEMPTS = 10; + + private static final long DEFAULT_INITIAL_INTERVAL = 2000; + + private static final int DEFAULT_MULTIPLIER = 5; + + private static final long DEFAULT_MAX_INTERVAL = 3 * 60000; + + private static final long SHORT_INITIAL_INTERVAL = 100; + + private static final Logger LOGGER = LoggerFactory.getLogger(RetryUtils.class); + + /** + * Default ResponseErrorHandler implementation. + */ public static final ResponseErrorHandler DEFAULT_RESPONSE_ERROR_HANDLER = new ResponseErrorHandler() { @Override - public boolean hasError(@NonNull ClientHttpResponse response) throws IOException { + public boolean hasError(final @NonNull ClientHttpResponse response) throws IOException { return response.getStatusCode().isError(); } @Override - public void handleError(URI url, HttpMethod method, @NonNull ClientHttpResponse response) throws IOException { + public void handleError(final URI url, final HttpMethod method, final @NonNull ClientHttpResponse response) + throws IOException { handleError(response); } @SuppressWarnings("removal") - public void handleError(@NonNull ClientHttpResponse response) throws IOException { + public void handleError(final @NonNull ClientHttpResponse response) throws IOException { if (response.getStatusCode().isError()) { String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8); String message = String.format("%s - %s", response.getStatusCode().value(), error); - /** + /* * Thrown on 4xx client errors, such as 401 - Incorrect API key provided, * 401 - You must be a member of an organization to use the API, 429 - - * Rate limit reached for requests, 429 - You exceeded your current quota - * , please check your plan and billing details. + * Rate limit reached for requests, 429 - You exceeded your current quota, + * please check your plan and billing details. */ if (response.getStatusCode().is4xxClientError()) { throw new NonTransientAiException(message); @@ -74,42 +91,68 @@ public void handleError(@NonNull ClientHttpResponse response) throws IOException throw new TransientAiException(message); } } + }; - private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); + /** + * Default RetryTemplate with exponential backoff configuration. + */ + public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = createDefaultRetryTemplate(); + + /** + * Short RetryTemplate for testing scenarios. + */ + public static final RetryTemplate SHORT_RETRY_TEMPLATE = createShortRetryTemplate(); - public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(TransientAiException.class) - .retryOn(ResourceAccessException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000L)) - .withListener(new RetryListener() { + private static RetryTemplate createDefaultRetryTemplate() { + RetryPolicy retryPolicy = RetryPolicy.builder() + .maxAttempts(DEFAULT_MAX_ATTEMPTS) + .includes(TransientAiException.class) + .includes(ResourceAccessException.class) + .delay(Duration.ofMillis(DEFAULT_INITIAL_INTERVAL)) + .multiplier(DEFAULT_MULTIPLIER) + .maxDelay(Duration.ofMillis(DEFAULT_MAX_INTERVAL)) + .build(); + + RetryTemplate retryTemplate = new RetryTemplate(retryPolicy); + retryTemplate.setRetryListener(new RetryListener() { + private final AtomicInteger retryCount = new AtomicInteger(0); @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:{}", context.getRetryCount(), throwable); + public void onRetryFailure(final RetryPolicy policy, final Retryable retryable, + final Throwable throwable) { + int currentRetries = this.retryCount.incrementAndGet(); + LOGGER.warn("Retry error. Retry count:{}", currentRetries, throwable); } - }) - .build(); + }); + return retryTemplate; + } /** - * Useful in testing scenarios where you don't want to wait long for retry and now - * show stack trace + * Useful in testing scenarios where you don't want to wait long for retry and don't + * need to show stack trace. + * @return a RetryTemplate with short delays */ - public static final RetryTemplate SHORT_RETRY_TEMPLATE = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(TransientAiException.class) - .retryOn(ResourceAccessException.class) - .fixedBackoff(Duration.ofMillis(100)) - .withListener(new RetryListener() { + private static RetryTemplate createShortRetryTemplate() { + RetryPolicy retryPolicy = RetryPolicy.builder() + .maxAttempts(DEFAULT_MAX_ATTEMPTS) + .includes(TransientAiException.class) + .includes(ResourceAccessException.class) + .delay(Duration.ofMillis(SHORT_INITIAL_INTERVAL)) + .build(); + + RetryTemplate retryTemplate = new RetryTemplate(retryPolicy); + retryTemplate.setRetryListener(new RetryListener() { + private final AtomicInteger retryCount = new AtomicInteger(0); @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:{}", context.getRetryCount()); + public void onRetryFailure(final RetryPolicy policy, final Retryable retryable, + final Throwable throwable) { + int currentRetries = this.retryCount.incrementAndGet(); + LOGGER.warn("Retry error. Retry count:{}", currentRetries, throwable); } - }) - .build(); + }); + return retryTemplate; + } } diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java index 95b6e37f668..90d43fe0e5c 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java @@ -19,18 +19,27 @@ /** * Root of the hierarchy of Model access exceptions that are considered transient - where * a previously failed operation might be able to succeed when the operation is retried - * without any intervention by application-level functionality. + * without any intervention. * * @author Christian Tzolov * @since 0.8.1 */ public class TransientAiException extends RuntimeException { - public TransientAiException(String message) { + /** + * Constructor with message. + * @param message the exception message + */ + public TransientAiException(final String message) { super(message); } - public TransientAiException(String message, Throwable cause) { + /** + * Constructor with message and cause. + * @param message the exception message + * @param cause the exception cause + */ + public TransientAiException(final String message, final Throwable cause) { super(message, cause); } diff --git a/spring-ai-retry/src/test/java/org/springframework/ai/retry/RetryUtilsTests.java b/spring-ai-retry/src/test/java/org/springframework/ai/retry/RetryUtilsTests.java index 1541edbc47a..3007ec16bfc 100644 --- a/spring-ai-retry/src/test/java/org/springframework/ai/retry/RetryUtilsTests.java +++ b/spring-ai-retry/src/test/java/org/springframework/ai/retry/RetryUtilsTests.java @@ -24,10 +24,11 @@ import org.junit.jupiter.api.Test; +import org.springframework.core.retry.RetryException; +import org.springframework.core.retry.RetryTemplate; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.client.ClientHttpResponse; -import org.springframework.retry.support.RetryTemplate; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -98,20 +99,20 @@ void shortRetryTemplateRetries() { AtomicInteger counter = new AtomicInteger(0); RetryTemplate template = RetryUtils.SHORT_RETRY_TEMPLATE; - assertThrows(TransientAiException.class, () -> template.execute(cb -> { + assertThrows(RetryException.class, () -> template.execute(() -> { counter.incrementAndGet(); throw new TransientAiException("test fail"); })); - assertEquals(10, counter.get()); + assertEquals(11, counter.get()); } @Test - void shortRetryTemplateSucceedsBeforeMaxAttempts() { + void shortRetryTemplateSucceedsBeforeMaxAttempts() throws RetryException { AtomicInteger counter = new AtomicInteger(0); RetryTemplate template = RetryUtils.SHORT_RETRY_TEMPLATE; - String result = template.execute(cb -> { + String result = template.execute(() -> { if (counter.incrementAndGet() < 5) { throw new TransientAiException("test fail"); } diff --git a/spring-ai-spring-boot-docker-compose/pom.xml b/spring-ai-spring-boot-docker-compose/pom.xml index bf480475df6..891cd2b6033 100644 --- a/spring-ai-spring-boot-docker-compose/pom.xml +++ b/spring-ai-spring-boot-docker-compose/pom.xml @@ -201,7 +201,7 @@ true - + org.springframework.ai @@ -215,6 +215,7 @@ spring-boot-starter-web test + org.springframework.boot spring-boot-starter-test diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIT.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIT.java index 686626b3de3..1b42dcb9213 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIT.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/docker/compose/service/connection/test/AbstractDockerComposeIT.java @@ -33,7 +33,7 @@ import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; import org.springframework.boot.testsupport.DisabledIfProcessUnavailable; -import org.springframework.boot.tomcat.autoconfigure.servlet.TomcatServletWebServerAutoConfiguration; +import org.springframework.boot.web.server.autoconfigure.servlet.ServletWebServerConfiguration; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; @@ -97,7 +97,7 @@ private File transformedComposeFile(File composeFile, DockerImageName imageName) } @Configuration(proxyBeanMethods = false) - @ImportAutoConfiguration(TomcatServletWebServerAutoConfiguration.class) + @ImportAutoConfiguration(ServletWebServerConfiguration.class) static class Config { } diff --git a/vector-stores/spring-ai-mariadb-store/pom.xml b/vector-stores/spring-ai-mariadb-store/pom.xml index 849abad92bd..599842a0d99 100644 --- a/vector-stores/spring-ai-mariadb-store/pom.xml +++ b/vector-stores/spring-ai-mariadb-store/pom.xml @@ -48,11 +48,6 @@ HikariCP - - org.springframework - spring-jdbc - - org.springframework.boot spring-boot-starter-jdbc diff --git a/vector-stores/spring-ai-milvus-store/pom.xml b/vector-stores/spring-ai-milvus-store/pom.xml index e791a0d192e..33efbed1a77 100644 --- a/vector-stores/spring-ai-milvus-store/pom.xml +++ b/vector-stores/spring-ai-milvus-store/pom.xml @@ -70,7 +70,6 @@ test - org.springframework.boot spring-boot-starter-test diff --git a/vector-stores/spring-ai-opensearch-store/pom.xml b/vector-stores/spring-ai-opensearch-store/pom.xml index 163905a501e..8ac21ebe310 100644 --- a/vector-stores/spring-ai-opensearch-store/pom.xml +++ b/vector-stores/spring-ai-opensearch-store/pom.xml @@ -106,7 +106,7 @@ micrometer-observation-test test - + diff --git a/vector-stores/spring-ai-oracle-store/pom.xml b/vector-stores/spring-ai-oracle-store/pom.xml index b496677f6fe..d9c8e44f0e3 100644 --- a/vector-stores/spring-ai-oracle-store/pom.xml +++ b/vector-stores/spring-ai-oracle-store/pom.xml @@ -69,14 +69,13 @@ - org.springframework - spring-jdbc + org.springframework.boot + spring-boot-starter-jdbc - - org.springframework.boot - spring-boot-starter-jdbc + org.springframework + spring-jdbc diff --git a/vector-stores/spring-ai-typesense-store/pom.xml b/vector-stores/spring-ai-typesense-store/pom.xml index ef89d40bc43..4a2bcc3fbe3 100644 --- a/vector-stores/spring-ai-typesense-store/pom.xml +++ b/vector-stores/spring-ai-typesense-store/pom.xml @@ -88,7 +88,6 @@ test - io.micrometer micrometer-observation-test