Skip to content

Commit e8cb1cb

Browse files
committed
feat: Add asynchronous tool calling support
Add optional async execution for tool callbacks to improve performance in I/O-intensive and high-concurrency scenarios. Key changes: - Add AsyncToolCallback interface extending ToolCallback - Add executeToolCallsAsync() method to ToolCallingManager - Update all 11 ChatModel implementations - Add 15 new async-specific tests - Full backward compatibility maintained Closes #4755 Signed-off-by: shaojie <741047428@qq.com>
1 parent 50db344 commit e8cb1cb

File tree

19 files changed

+1384
-330
lines changed

19 files changed

+1384
-330
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.slf4j.LoggerFactory;
3333
import reactor.core.publisher.Flux;
3434
import reactor.core.publisher.Mono;
35-
import reactor.core.scheduler.Schedulers;
3635

3736
import org.springframework.ai.anthropic.api.AnthropicApi;
3837
import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage;
@@ -268,42 +267,37 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
268267
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
269268
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
270269

271-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
272-
273-
if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
274-
// FIXME: bounded elastic needs to be used since tool calling
275-
// is currently only synchronous
276-
return Flux.deferContextual(ctx -> {
277-
// TODO: factor out the tool execution logic with setting context into a utility.
278-
ToolExecutionResult toolExecutionResult;
279-
try {
280-
ToolCallReactiveContextHolder.setContext(ctx);
281-
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
282-
}
283-
finally {
284-
ToolCallReactiveContextHolder.clearContext();
285-
}
286-
if (toolExecutionResult.returnDirect()) {
287-
// Return tool execution result directly to the client.
288-
return Flux.just(ChatResponse.builder().from(chatResponse)
289-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
290-
.build());
291-
}
292-
else {
293-
// Send the tool execution result back to the model.
294-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
295-
chatResponse);
296-
}
297-
}).subscribeOn(Schedulers.boundedElastic());
298-
}
299-
else {
300-
return Mono.empty();
301-
}
270+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
271+
272+
if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
273+
return Flux.deferContextual(ctx -> {
274+
// TODO: factor out the tool execution logic with setting context into a utility.
275+
ToolCallReactiveContextHolder.setContext(ctx);
276+
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
277+
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
278+
.flatMapMany(toolExecutionResult -> {
279+
if (toolExecutionResult.returnDirect()) {
280+
// Return tool execution result directly to the client.
281+
return Flux.just(ChatResponse.builder().from(chatResponse)
282+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
283+
.build());
284+
}
285+
else {
286+
// Send the tool execution result back to the model.
287+
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
288+
chatResponse);
289+
}
290+
});
291+
});
302292
}
303293
else {
304-
// If internal tool execution is not required, just return the chat response.
305-
return Mono.just(chatResponse);
294+
return Mono.empty();
306295
}
296+
}
297+
else {
298+
// If internal tool execution is not required, just return the chat response.
299+
return Mono.just(chatResponse);
300+
}
307301
})
308302
.doOnError(observation::error)
309303
.doFinally(s -> observation.stop())

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
import org.slf4j.Logger;
6464
import org.slf4j.LoggerFactory;
6565
import reactor.core.publisher.Flux;
66-
import reactor.core.scheduler.Schedulers;
6766

6867
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
6968
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
@@ -379,31 +378,27 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
379378

380379
return chatResponseFlux.flatMapSequential(chatResponse -> {
381380
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
382-
// FIXME: bounded elastic needs to be used since tool calling
383-
// is currently only synchronous
384381
return Flux.deferContextual(ctx -> {
385-
ToolExecutionResult toolExecutionResult;
386-
try {
387-
ToolCallReactiveContextHolder.setContext(ctx);
388-
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
389-
}
390-
finally {
391-
ToolCallReactiveContextHolder.clearContext();
392-
}
393-
if (toolExecutionResult.returnDirect()) {
394-
// Return tool execution result directly to the client.
395-
return Flux.just(ChatResponse.builder()
396-
.from(chatResponse)
397-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
398-
.build());
399-
}
400-
else {
401-
// Send the tool execution result back to the model.
402-
return this.internalStream(
403-
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
404-
chatResponse);
405-
}
406-
}).subscribeOn(Schedulers.boundedElastic());
382+
ToolCallReactiveContextHolder.setContext(ctx);
383+
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
384+
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
385+
.flatMapMany(toolExecutionResult -> {
386+
if (toolExecutionResult.returnDirect()) {
387+
// Return tool execution result directly to the
388+
// client.
389+
return Flux.just(ChatResponse.builder()
390+
.from(chatResponse)
391+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
392+
.build());
393+
}
394+
else {
395+
// Send the tool execution result back to the model.
396+
return this.internalStream(
397+
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
398+
chatResponse);
399+
}
400+
});
401+
});
407402
}
408403

409404
Flux<ChatResponse> flux = Flux.just(chatResponse)

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.slf4j.Logger;
3636
import org.slf4j.LoggerFactory;
3737
import reactor.core.publisher.Flux;
38-
import reactor.core.scheduler.Schedulers;
3938
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4039
import software.amazon.awssdk.core.SdkBytes;
4140
import software.amazon.awssdk.core.document.Document;
@@ -805,32 +804,27 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
805804
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)
806805
&& chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) {
807806

808-
// FIXME: bounded elastic needs to be used since tool calling
809-
// is currently only synchronous
810807
return Flux.deferContextual(ctx -> {
811-
ToolExecutionResult toolExecutionResult;
812-
try {
813-
ToolCallReactiveContextHolder.setContext(ctx);
814-
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
815-
}
816-
finally {
817-
ToolCallReactiveContextHolder.clearContext();
818-
}
819-
820-
if (toolExecutionResult.returnDirect()) {
821-
// Return tool execution result directly to the client.
822-
return Flux.just(ChatResponse.builder()
823-
.from(chatResponse)
824-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
825-
.build());
826-
}
827-
else {
828-
// Send the tool execution result back to the model.
829-
return this.internalStream(
830-
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
831-
chatResponse);
832-
}
833-
}).subscribeOn(Schedulers.boundedElastic());
808+
ToolCallReactiveContextHolder.setContext(ctx);
809+
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
810+
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
811+
.flatMapMany(toolExecutionResult -> {
812+
if (toolExecutionResult.returnDirect()) {
813+
// Return tool execution result directly to the
814+
// client.
815+
return Flux.just(ChatResponse.builder()
816+
.from(chatResponse)
817+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
818+
.build());
819+
}
820+
else {
821+
// Send the tool execution result back to the model.
822+
return this.internalStream(
823+
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
824+
chatResponse);
825+
}
826+
});
827+
});
834828
}
835829
else {
836830
return Flux.just(chatResponse);

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.slf4j.LoggerFactory;
2828
import reactor.core.publisher.Flux;
2929
import reactor.core.publisher.Mono;
30-
import reactor.core.scheduler.Schedulers;
3130

3231
import org.springframework.ai.chat.messages.AssistantMessage;
3332
import org.springframework.ai.chat.messages.MessageType;
@@ -285,36 +284,31 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
285284
}));
286285

287286
// @formatter:off
288-
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
289-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
290-
// FIXME: bounded elastic needs to be used since tool calling
291-
// is currently only synchronous
292-
return Flux.deferContextual(ctx -> {
293-
ToolExecutionResult toolExecutionResult;
294-
try {
295-
ToolCallReactiveContextHolder.setContext(ctx);
296-
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
297-
}
298-
finally {
299-
ToolCallReactiveContextHolder.clearContext();
300-
}
301-
if (toolExecutionResult.returnDirect()) {
302-
// Return tool execution result directly to the client.
303-
return Flux.just(ChatResponse.builder().from(response)
304-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
305-
.build());
306-
}
307-
else {
308-
// Send the tool execution result back to the model.
309-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
310-
response);
311-
}
312-
}).subscribeOn(Schedulers.boundedElastic());
313-
}
314-
else {
315-
return Flux.just(response);
316-
}
317-
})
287+
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
288+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
289+
return Flux.deferContextual(ctx -> {
290+
ToolCallReactiveContextHolder.setContext(ctx);
291+
return this.toolCallingManager.executeToolCallsAsync(prompt, response)
292+
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
293+
.flatMapMany(toolExecutionResult -> {
294+
if (toolExecutionResult.returnDirect()) {
295+
// Return tool execution result directly to the client.
296+
return Flux.just(ChatResponse.builder().from(response)
297+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
298+
.build());
299+
}
300+
else {
301+
// Send the tool execution result back to the model.
302+
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
303+
response);
304+
}
305+
});
306+
});
307+
}
308+
else {
309+
return Flux.just(response);
310+
}
311+
})
318312
.doOnError(observation::error)
319313
.doFinally(s -> observation.stop())
320314
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import org.slf4j.Logger;
4848
import org.slf4j.LoggerFactory;
4949
import reactor.core.publisher.Flux;
50-
import reactor.core.scheduler.Schedulers;
5150

5251
import org.springframework.ai.chat.messages.AssistantMessage;
5352
import org.springframework.ai.chat.messages.Message;
@@ -548,35 +547,30 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
548547
});
549548

550549
// @formatter:off
551-
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
552-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
553-
// FIXME: bounded elastic needs to be used since tool calling
554-
// is currently only synchronous
555-
return Flux.deferContextual(ctx -> {
556-
ToolExecutionResult toolExecutionResult;
557-
try {
558-
ToolCallReactiveContextHolder.setContext(ctx);
559-
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
560-
}
561-
finally {
562-
ToolCallReactiveContextHolder.clearContext();
563-
}
564-
if (toolExecutionResult.returnDirect()) {
565-
// Return tool execution result directly to the client.
566-
return Flux.just(ChatResponse.builder().from(response)
567-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
568-
.build());
569-
}
570-
else {
571-
// Send the tool execution result back to the model.
572-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
573-
}
574-
}).subscribeOn(Schedulers.boundedElastic());
575-
}
576-
else {
577-
return Flux.just(response);
578-
}
579-
})
550+
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
551+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
552+
return Flux.deferContextual(ctx -> {
553+
ToolCallReactiveContextHolder.setContext(ctx);
554+
return this.toolCallingManager.executeToolCallsAsync(prompt, response)
555+
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
556+
.flatMapMany(toolExecutionResult -> {
557+
if (toolExecutionResult.returnDirect()) {
558+
// Return tool execution result directly to the client.
559+
return Flux.just(ChatResponse.builder().from(response)
560+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
561+
.build());
562+
}
563+
else {
564+
// Send the tool execution result back to the model.
565+
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
566+
}
567+
});
568+
});
569+
}
570+
else {
571+
return Flux.just(response);
572+
}
573+
})
580574
.doOnError(observation::error)
581575
.doFinally(s -> observation.stop())
582576
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/schema/GoogleGenAiToolCallingManager.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.List;
2020

2121
import com.fasterxml.jackson.databind.node.ObjectNode;
22+
import reactor.core.publisher.Mono;
2223

2324
import org.springframework.ai.chat.model.ChatResponse;
2425
import org.springframework.ai.chat.prompt.Prompt;
@@ -96,4 +97,17 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp
9697
return this.delegateToolCallingManager.executeToolCalls(prompt, chatResponse);
9798
}
9899

100+
/**
101+
* Executes tool calls asynchronously by delegating to the underlying tool calling
102+
* manager.
103+
* @param prompt the original prompt that triggered the tool calls
104+
* @param chatResponse the chat response containing the tool calls to execute
105+
* @return a Mono that emits the result of executing the tool calls
106+
* @since 1.2.0
107+
*/
108+
@Override
109+
public Mono<ToolExecutionResult> executeToolCallsAsync(Prompt prompt, ChatResponse chatResponse) {
110+
return this.delegateToolCallingManager.executeToolCallsAsync(prompt, chatResponse);
111+
}
112+
99113
}

0 commit comments

Comments
 (0)