Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import dev.langchain4j.agentic.agent.AgentRequest;
import dev.langchain4j.agentic.agent.AgentResponse;
import dev.langchain4j.agentic.declarative.TypedKey;
import dev.langchain4j.agentic.internal.AgentExecutor;
import dev.langchain4j.agentic.scope.AgenticScope;
import dev.langchain4j.agentic.workflow.ConditionalAgentService;
Expand Down Expand Up @@ -48,6 +49,11 @@ public ConditionalAgentService<T> subAgents(List<AgentExecutor> agentExecutors)
return this.subAgents(agentExecutors.toArray());
}

@Override
public ConditionalAgentService<T> outputKey(Class<? extends TypedKey<?>> outputKey) {
throw new UnsupportedOperationException("Feature not implemented yet");
}

@Override
public ConditionalAgentService<T> beforeAgentInvocation(Consumer<AgentRequest> consumer) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import dev.langchain4j.agentic.agent.AgentRequest;
import dev.langchain4j.agentic.agent.AgentResponse;
import dev.langchain4j.agentic.declarative.TypedKey;
import dev.langchain4j.agentic.internal.AgentExecutor;
import dev.langchain4j.agentic.scope.AgenticScope;
import dev.langchain4j.agentic.workflow.LoopAgentService;
Expand Down Expand Up @@ -77,6 +78,11 @@ public LoopAgentService<T> subAgents(List<AgentExecutor> agentExecutors) {
return this;
}

@Override
public LoopAgentService<T> outputKey(Class<? extends TypedKey<?>> outputKey) {
throw new UnsupportedOperationException("Feature not implemented yet");
}

@Override
public LoopAgentService<T> beforeAgentInvocation(Consumer<AgentRequest> consumer) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import dev.langchain4j.agentic.agent.AgentRequest;
import dev.langchain4j.agentic.agent.AgentResponse;
import dev.langchain4j.agentic.declarative.TypedKey;
import dev.langchain4j.agentic.internal.AgentExecutor;
import dev.langchain4j.agentic.workflow.ParallelAgentService;
import io.serverlessworkflow.impl.ExecutorServiceHolder;
Expand Down Expand Up @@ -47,6 +48,11 @@ public ParallelAgentService<T> subAgents(List<AgentExecutor> agentExecutors) {
return this.subAgents(agentExecutors.toArray());
}

@Override
public ParallelAgentService<T> outputKey(Class<? extends TypedKey<?>> outputKey) {
throw new UnsupportedOperationException("Feature not implemented yet");
}

@Override
public ParallelAgentService<T> beforeAgentInvocation(Consumer<AgentRequest> consumer) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import dev.langchain4j.agentic.agent.AgentRequest;
import dev.langchain4j.agentic.agent.AgentResponse;
import dev.langchain4j.agentic.declarative.TypedKey;
import dev.langchain4j.agentic.internal.AgentExecutor;
import dev.langchain4j.agentic.workflow.SequentialAgentService;
import java.util.List;
Expand Down Expand Up @@ -46,6 +47,11 @@ public SequentialAgentService<T> subAgents(List<AgentExecutor> agentExecutors) {
return this;
}

@Override
public SequentialAgentService<T> outputKey(Class<? extends TypedKey<?>> outputKey) {
throw new UnsupportedOperationException("Feature not implemented yet");
}

@Override
public SequentialAgentService<T> beforeAgentInvocation(Consumer<AgentRequest> consumer) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static dev.langchain4j.agentic.internal.AgentUtil.agentsToExecutors;

import dev.langchain4j.agentic.internal.AgentExecutor;
import dev.langchain4j.agentic.scope.AgentInvocationListener;
import dev.langchain4j.agentic.scope.AgenticScope;
import dev.langchain4j.agentic.scope.DefaultAgenticScope;
import io.serverlessworkflow.api.types.func.LoopPredicateIndex;
Expand All @@ -35,7 +36,7 @@ public static List<AgentExecutor> toExecutors(Object... agents) {
}

public static Function<DefaultAgenticScope, Object> toFunction(AgentExecutor exec) {
return exec::execute;
return agenticScope -> exec.execute(agenticScope, AgentInvocationListener.NO_OP);
}

public static LoopPredicateIndex<AgenticScope, Object> toWhile(Predicate<AgenticScope> exit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,76 @@ public interface NovelCreator {
String createNovel(
@V("topic") String topic, @V("audience") String audience, @V("style") String style);
}

public static Agents.StorySeedAgent newStorySeedAgent() {
return spy(
AgenticServices.agentBuilder(Agents.StorySeedAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.PlotAgent newPlotAgent() {
return spy(
AgenticServices.agentBuilder(Agents.PlotAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.SceneAgent newSceneAgent() {
return spy(
AgenticServices.agentBuilder(Agents.SceneAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.SettingAgent newSettingAgent() {
return spy(
AgenticServices.agentBuilder(Agents.SettingAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.HeroAgent newHeroAgent() {
return spy(
AgenticServices.agentBuilder(Agents.HeroAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.ConflictAgent newConflictAgent() {
return spy(
AgenticServices.agentBuilder(Agents.ConflictAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.FactAgent newFactAgent() {
return spy(
AgenticServices.agentBuilder(Agents.FactAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.CultureAgent newCultureAgent() {
return spy(
AgenticServices.agentBuilder(Agents.CultureAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}

public static Agents.TechnologyAgent newTechnologyAgent() {
return spy(
AgenticServices.agentBuilder(Agents.TechnologyAgent.class)
.chatModel(BASE_MODEL)
.outputKey("response")
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@
package io.serverlessworkflow.fluent.agentic;

import static io.serverlessworkflow.fluent.agentic.Agents.*;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newConflictAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newCultureAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newFactAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newHeroAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newPlotAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newSceneAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newSettingAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newStorySeedAgent;
import static io.serverlessworkflow.fluent.agentic.AgentsUtils.newTechnologyAgent;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -38,9 +48,12 @@ class WorkflowTests {

@Test
public void testAgent() throws ExecutionException, InterruptedException {
final StorySeedAgent storySeedAgent = mock(StorySeedAgent.class);
final StorySeedAgent storySeedAgent = newStorySeedAgent();

doReturn("storySeedAgent")
.when(storySeedAgent)
.invoke(org.mockito.ArgumentMatchers.anyString());

when(storySeedAgent.invoke(eq("A Great Story"))).thenReturn("storySeedAgent");
when(storySeedAgent.outputKey()).thenReturn("premise");
when(storySeedAgent.name()).thenReturn("storySeedAgent");

Expand All @@ -67,21 +80,23 @@ public void testAgent() throws ExecutionException, InterruptedException {

@Test
public void testAgents() throws ExecutionException, InterruptedException {
final StorySeedAgent storySeedAgent = mock(StorySeedAgent.class);
final PlotAgent plotAgent = mock(PlotAgent.class);
final SceneAgent sceneAgent = mock(SceneAgent.class);
final StorySeedAgent storySeedAgent = newStorySeedAgent();
final PlotAgent plotAgent = newPlotAgent();
final SceneAgent sceneAgent = newSceneAgent();

when(storySeedAgent.invoke(eq("A Great Story"))).thenReturn("storySeedAgent");
when(storySeedAgent.outputKey()).thenReturn("premise");
when(storySeedAgent.name()).thenReturn("storySeedAgent");
doReturn("storySeedAgent")
.when(storySeedAgent)
.invoke(org.mockito.ArgumentMatchers.anyString());

when(plotAgent.invoke(eq("storySeedAgent"))).thenReturn("plotAgent");
when(plotAgent.outputKey()).thenReturn("plot");
when(plotAgent.name()).thenReturn("plotAgent");
doReturn("plotAgent").when(plotAgent).invoke(org.mockito.ArgumentMatchers.anyString());

when(sceneAgent.invoke(eq("plotAgent"))).thenReturn("sceneAgent");
when(sceneAgent.outputKey()).thenReturn("story");
when(sceneAgent.name()).thenReturn("sceneAgent");
doReturn("sceneAgent").when(sceneAgent).invoke(org.mockito.ArgumentMatchers.anyString());

Workflow workflow =
AgentWorkflowBuilder.workflow("storyFlow")
Expand Down Expand Up @@ -110,21 +125,25 @@ public void testAgents() throws ExecutionException, InterruptedException {

@Test
public void testSequence() throws ExecutionException, InterruptedException {
final StorySeedAgent storySeedAgent = mock(StorySeedAgent.class);
final PlotAgent plotAgent = mock(PlotAgent.class);
final SceneAgent sceneAgent = mock(SceneAgent.class);
final StorySeedAgent storySeedAgent = newStorySeedAgent();
final PlotAgent plotAgent = newPlotAgent();
final SceneAgent sceneAgent = newSceneAgent();

when(storySeedAgent.invoke(eq("A Great Story"))).thenReturn("storySeedAgent");
when(storySeedAgent.outputKey()).thenReturn("premise");
when(storySeedAgent.name()).thenReturn("storySeedAgent");

when(plotAgent.invoke(eq("storySeedAgent"))).thenReturn("plotAgent");
doReturn("storySeedAgent")
.when(storySeedAgent)
.invoke(org.mockito.ArgumentMatchers.anyString());

when(plotAgent.outputKey()).thenReturn("plot");
when(plotAgent.name()).thenReturn("plotAgent");

when(sceneAgent.invoke(eq("plotAgent"))).thenReturn("sceneAgent");
doReturn("plotAgent").when(plotAgent).invoke(org.mockito.ArgumentMatchers.anyString());

when(sceneAgent.outputKey()).thenReturn("story");
when(sceneAgent.name()).thenReturn("sceneAgent");
doReturn("sceneAgent").when(sceneAgent).invoke(org.mockito.ArgumentMatchers.anyString());

Workflow workflow =
AgentWorkflowBuilder.workflow("storyFlow")
Expand All @@ -149,22 +168,25 @@ public void testSequence() throws ExecutionException, InterruptedException {

@Test
public void testParallel() throws ExecutionException, InterruptedException {
final SettingAgent setting = newSettingAgent();
final HeroAgent hero = newHeroAgent();
final ConflictAgent conflict = newConflictAgent();

final SettingAgent setting = mock(SettingAgent.class);
final HeroAgent hero = mock(HeroAgent.class);
final ConflictAgent conflict = mock(ConflictAgent.class);

when(setting.invoke(eq("sci-fi"))).thenReturn("Fake conflict response");
when(setting.outputKey()).thenReturn("setting");
when(setting.name()).thenReturn("setting");
doReturn("Fake setting response")
.when(setting)
.invoke(org.mockito.ArgumentMatchers.anyString());

when(hero.invoke(eq("sci-fi"))).thenReturn("Fake hero response");
when(hero.outputKey()).thenReturn("hero");
when(hero.name()).thenReturn("hero");
doReturn("Fake hero response").when(hero).invoke(org.mockito.ArgumentMatchers.anyString());

when(conflict.invoke(eq("sci-fi"))).thenReturn("Fake setting response");
when(conflict.outputKey()).thenReturn("conflict");
when(conflict.name()).thenReturn("conflict");
doReturn("Fake conflict response")
.when(conflict)
.invoke(org.mockito.ArgumentMatchers.anyString());

Workflow workflow =
AgentWorkflowBuilder.workflow("parallelFlow")
Expand All @@ -178,9 +200,9 @@ public void testParallel() throws ExecutionException, InterruptedException {
Map<String, Object> result =
app.workflowDefinition(workflow).instance(topic).start().get().asMap().orElseThrow();

assertEquals("Fake conflict response", result.get("setting").toString());
assertEquals("Fake setting response", result.get("setting").toString());
assertEquals("Fake hero response", result.get("hero").toString());
assertEquals("Fake setting response", result.get("conflict").toString());
assertEquals("Fake conflict response", result.get("conflict").toString());
}

try (WorkflowApplication app = WorkflowApplication.builder().build()) {
Expand All @@ -192,35 +214,40 @@ public void testParallel() throws ExecutionException, InterruptedException {
.as(AgenticScope.class)
.orElseThrow();

assertEquals("Fake conflict response", result.readState("setting").toString());
assertEquals("Fake setting response", result.readState("setting").toString());
assertEquals("Fake hero response", result.readState("hero").toString());
assertEquals("Fake setting response", result.readState("conflict").toString());
assertEquals("Fake conflict response", result.readState("conflict").toString());
}
}

@Test
public void testSeqAndThenParallel() throws ExecutionException, InterruptedException {
final FactAgent factAgent = mock(FactAgent.class);
final CultureAgent cultureAgent = mock(CultureAgent.class);
final TechnologyAgent technologyAgent = mock(TechnologyAgent.class);
final FactAgent factAgent = newFactAgent();
final CultureAgent cultureAgent = newCultureAgent();
final TechnologyAgent technologyAgent = newTechnologyAgent();

List<String> cultureTraits =
List.of("Alien Culture Trait 1", "Alien Culture Trait 2", "Alien Culture Trait 3");

List<String> technologyTraits =
List.of("Alien Technology Trait 1", "Alien Technology Trait 2", "Alien Technology Trait 3");

when(factAgent.invoke(eq("alien"))).thenReturn("Some Fact about aliens");
when(factAgent.outputKey()).thenReturn("fact");
when(factAgent.name()).thenReturn("fact");
doReturn("Some Fact about aliens")
.when(factAgent)
.invoke(org.mockito.ArgumentMatchers.anyString());

when(cultureAgent.invoke(eq("Some Fact about aliens"))).thenReturn(cultureTraits);
when(cultureAgent.outputKey()).thenReturn("culture");
when(cultureAgent.name()).thenReturn("culture");
doReturn(cultureTraits).when(cultureAgent).invoke(org.mockito.ArgumentMatchers.anyString());

when(technologyAgent.invoke(eq("Some Fact about aliens"))).thenReturn(technologyTraits);
when(technologyAgent.outputKey()).thenReturn("technology");
when(technologyAgent.name()).thenReturn("technology");
doReturn(technologyTraits)
.when(technologyAgent)
.invoke(org.mockito.ArgumentMatchers.anyString());

Workflow workflow =
AgentWorkflowBuilder.workflow("alienCultureFlow")
.tasks(
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
<version.org.hibernate.validator>9.1.0.Final</version.org.hibernate.validator>
<version.org.glassfish.expressly>6.0.0</version.org.glassfish.expressly>
<!-- Experimental modules from langchain4j -->
<version.dev.langchain4j.beta>1.8.0-beta15</version.dev.langchain4j.beta>
<version.dev.langchain4j.beta>1.9.0-beta16</version.dev.langchain4j.beta>
<!-- Base langchain4j version -->
<version.dev.langchain4j>1.8.0</version.dev.langchain4j>
<!-- Swagger Parser -->
Expand Down