
Commit e459388

ewilliams-cloudera, jkwatson, and baasitsharief authored
misc. enhancements (#294)
* better startup script, maybe
* Kill the java process if the startup script is killed.
* 3000 tries for python
* don't make everything need to work
* wip for providing nonstreaming - still no tool calling non-streaming
* fix non-streaming tool calling
* wip on adding loading state for non-streaming
* mypy
* ui linting
* add gpt-5-chat to azure
* remove gpt-5-chat for azure
* remove prints
* Update dependencies and enhance tool calling functionality with non-streaming support
* ruff
* mypy
* switch to disabling streaming, rather than enabling
* bump bedrock converse versions
* mypy
* bits of cleanup when reviewing the code
* bump versions of docling, remove llama-index base dependency, and revert llama index bedrock version
* bump version of llama-index, refactor streaming event, fix bug with default use streaming value

Co-authored-by: jwatson <jkwatson@gmail.com>
Co-authored-by: Baasit Sharief <baasitsharief@gmail.com>
1 parent ebf1f4b commit e459388

File tree

22 files changed: +3086 -2718 lines changed


backend/src/main/java/com/cloudera/cai/rag/Types.java

Lines changed: 2 additions & 0 deletions
@@ -103,10 +103,12 @@ public record RagDataSource(
       Long associatedSessionId) {}

   @With
+  @Builder
   public record QueryConfiguration(
       boolean enableHyde,
       boolean enableSummaryFilter,
       boolean enableToolCalling,
+      Boolean disableStreaming,
       List<String> selectedTools) {}

   @With

backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java

Lines changed: 13 additions & 2 deletions
@@ -43,6 +43,7 @@
 import com.cloudera.cai.rag.configuration.JdbiConfiguration;
 import com.cloudera.cai.util.exceptions.NotFound;
 import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.DeserializationFeature;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import java.time.Instant;
 import java.util.List;
@@ -57,9 +58,16 @@
 @Component
 public class SessionRepository {
   public static final Types.QueryConfiguration DEFAULT_QUERY_CONFIGURATION =
-      new Types.QueryConfiguration(false, true, false, List.of());
+      Types.QueryConfiguration.builder()
+          .enableHyde(false)
+          .enableSummaryFilter(true)
+          .enableToolCalling(false)
+          .disableStreaming(true)
+          .selectedTools(List.of())
+          .build();
   private final DatabaseOperations databaseOperations;
-  private final ObjectMapper objectMapper = new ObjectMapper();
+  private final ObjectMapper objectMapper =
+      new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);

   public SessionRepository(DatabaseOperations databaseOperations) {
     this.databaseOperations = databaseOperations;
@@ -177,6 +185,9 @@ private Types.QueryConfiguration extractQueryConfiguration(RowView rowView)
     if (queryConfiguration.selectedTools() == null) {
       queryConfiguration = queryConfiguration.withSelectedTools(List.of());
     }
+    if (queryConfiguration.disableStreaming() == null) {
+      queryConfiguration = queryConfiguration.withDisableStreaming(false);
+    }
     return queryConfiguration;
   }
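The two changes above make stored query-configuration JSON tolerant in both directions: unknown keys are ignored on read (FAIL_ON_UNKNOWN_PROPERTIES disabled), and rows written before the new field existed are backfilled with a default. A minimal Python sketch of the same tolerant-read pattern, assuming a hypothetical QueryConfiguration dataclass (not the project's actual model):

import json
from dataclasses import dataclass, field, fields


@dataclass
class QueryConfiguration:
    enable_hyde: bool = False
    enable_summary_filter: bool = True
    enable_tool_calling: bool = False
    disable_streaming: bool = False
    selected_tools: list[str] = field(default_factory=list)


def from_json(raw: str) -> QueryConfiguration:
    data = json.loads(raw)
    known = {f.name for f in fields(QueryConfiguration)}
    # Drop unknown keys (the FAIL_ON_UNKNOWN_PROPERTIES analogue);
    # missing keys fall back to the dataclass defaults (the backfill analogue).
    return QueryConfiguration(**{k: v for k, v in data.items() if k in known})


# A row written by an older or newer version of the app still parses:
print(from_json('{"enable_hyde": true, "some_future_flag": 1}').disable_streaming)  # False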

backend/src/test/java/com/cloudera/cai/rag/TestData.java

Lines changed: 13 additions & 2 deletions
@@ -83,7 +83,12 @@ public static Types.Session createTestSessionInstance(
         null,
         "test-rerank-model",
         3,
-        new Types.QueryConfiguration(false, true, true, List.of()));
+        Types.QueryConfiguration.builder()
+            .enableSummaryFilter(true)
+            .enableToolCalling(true)
+            .disableStreaming(true)
+            .selectedTools(List.of())
+            .build());
   }

   public static Types.CreateSession createSessionInstance(String sessionName) {
@@ -99,7 +104,13 @@ public static Types.CreateSession createSessionInstance(
         null,
         "test-rerank-model",
         3,
-        new Types.QueryConfiguration(false, true, true, List.of()),
+        Types.QueryConfiguration.builder()
+            .enableHyde(false)
+            .enableSummaryFilter(true)
+            .enableToolCalling(true)
+            .disableStreaming(true)
+            .selectedTools(List.of())
+            .build(),
         projectId);
   }

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java

Lines changed: 10 additions & 4 deletions
@@ -137,6 +137,14 @@ void update() {
     var updatedRerankModel = "new-rerank-model";
     var updatedProjectId = project.id();

+    var queryConfiguration =
+        Types.QueryConfiguration.builder()
+            .enableHyde(true)
+            .enableSummaryFilter(false)
+            .enableToolCalling(true)
+            .disableStreaming(true)
+            .selectedTools(List.of("foo"))
+            .build();
     var updatedSession =
         sessionController.update(
             insertedSession
@@ -145,8 +153,7 @@ void update() {
                 .withRerankModel(updatedRerankModel)
                 .withName(updatedName)
                 .withProjectId(updatedProjectId)
-                .withQueryConfiguration(
-                    new Types.QueryConfiguration(true, false, true, List.of("foo"))),
+                .withQueryConfiguration(queryConfiguration),
             request);

     assertThat(updatedSession.id()).isNotNull();
@@ -160,8 +167,7 @@ void update() {
     assertThat(updatedSession.timeUpdated()).isAfter(insertedSession.timeUpdated());
     assertThat(updatedSession.createdById()).isEqualTo("test-user");
     assertThat(updatedSession.lastInteractionTime()).isNull();
-    assertThat(updatedSession.queryConfiguration())
-        .isEqualTo(new Types.QueryConfiguration(true, false, true, List.of("foo")));
+    assertThat(updatedSession.queryConfiguration()).isEqualTo(queryConfiguration);
   }

   @Test

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -301,6 +301,18 @@ def generate_stream() -> Generator[str, None, None]:

     first_message = True
     stream = future.result()
+
+    # If streaming is disabled, immediately send a loading event to show StreamedEvents
+    if session.query_configuration.disable_streaming:
+        loading = ChatEvent(
+            type="thinking",
+            name="thinking",
+            timestamp=time.time(),
+            data="Preparing full response...",
+        )
+        event_json = json.dumps({"event": loading.model_dump()})
+        yield f"data: {event_json}\n\n"
+        first_message = False
     for item in stream:
         response: ChatResponse = item
         # Check for cancellation between each response
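The synthetic "thinking" event above goes out as a standard server-sent-events frame (a "data: ..." line with a JSON payload, terminated by a blank line), so the UI gets immediate feedback even though the real answer arrives later in one piece. A minimal sketch of consuming such frames, assuming an iterable of raw response lines; the payloads shown are illustrative:

import json
from typing import Any, Iterable, Iterator


def read_sse_events(lines: Iterable[str]) -> Iterator[dict[str, Any]]:
    # Yield the JSON payload of each "data: ..." SSE frame.
    for line in lines:
        line = line.strip()
        if line.startswith("data: "):
            yield json.loads(line[len("data: "):])


frames = [
    'data: {"event": {"type": "thinking", "name": "thinking", "data": "Preparing full response..."}}',
    "",
    'data: {"text": "...the complete answer..."}',
    "",
]
for payload in read_sse_events(frames):
    print(payload)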

llm-service/app/services/chat/streaming_chat.py

Lines changed: 13 additions & 6 deletions
@@ -83,6 +83,7 @@ def stream_chat(
         use_hyde=session.query_configuration.enable_hyde,
         use_summary_filter=session.query_configuration.enable_summary_filter,
         use_tool_calling=session.query_configuration.enable_tool_calling,
+        use_streaming=not session.query_configuration.disable_streaming,
     )

     response_id = str(uuid.uuid4())
@@ -188,14 +189,20 @@ def _stream_direct_llm_chat(
     user_name: Optional[str],
 ) -> Generator[ChatResponse, None, None]:
     record_direct_llm_mlflow_run(response_id, session, user_name)
-
-    chat_response = llm_completion.stream_completion(
-        session.id, query, session.inference_model
-    )
-    response: ChatResponse = ChatResponse(message=ChatMessage(content=query))
-    for response in chat_response:
+    response: ChatResponse
+    if session.query_configuration.disable_streaming:
+        # Use non-streaming completion when streaming is disabled
+        response = llm_completion.completion(session.id, query, session.inference_model)
         response.additional_kwargs["response_id"] = response_id
         yield response
+    else:
+        chat_response = llm_completion.stream_completion(
+            session.id, query, session.inference_model
+        )
+        response = ChatResponse(message=ChatMessage(content=query))
+        for response in chat_response:
+            response.additional_kwargs["response_id"] = response_id
+            yield response

     new_chat_message = RagStudioChatMessage(
         id=response_id,
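Both branches of _stream_direct_llm_chat now honor the same Generator[ChatResponse, None, None] contract: the non-streaming branch yields exactly one fully formed response, so callers iterate identically in either mode. A stripped-down sketch of that shape; complete and stream below are stand-ins for llm_completion.completion and llm_completion.stream_completion, not their real signatures:

from typing import Callable, Generator, Iterable


def chat(
    query: str,
    disable_streaming: bool,
    complete: Callable[[str], str],
    stream: Callable[[str], Iterable[str]],
) -> Generator[str, None, None]:
    # One yield when streaming is off; many small yields otherwise.
    if disable_streaming:
        yield complete(query)
    else:
        yield from stream(query)


# Callers don't branch on the mode:
print("".join(chat("hi", True, str.upper, iter)))   # "HI"
print("".join(chat("hi", False, str.upper, iter)))  # "hi"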

llm-service/app/services/metadata_apis/session_metadata_api.py

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,7 @@ class SessionQueryConfiguration:
     enable_summary_filter: bool
     enable_tool_calling: bool = False
     selected_tools: list[str] = field(default_factory=list)
+    disable_streaming: bool = False


 @dataclass
@@ -126,6 +127,9 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
             enable_tool_calling=data["queryConfiguration"].get(
                 "enableToolCalling", False
             ),
+            disable_streaming=data["queryConfiguration"].get(
+                "disableStreaming", False
+            ),
             selected_tools=data["queryConfiguration"]["selectedTools"] or [],
         ),
         associated_data_source_id=data.get("associatedDataSourceId", None),
@@ -146,6 +150,7 @@ def update_session(session: Session, user_name: Optional[str]) -> Session:
             "enableSummaryFilter": session.query_configuration.enable_summary_filter,
             "enableToolCalling": session.query_configuration.enable_tool_calling,
             "selectedTools": session.query_configuration.selected_tools,
+            "disableStreaming": session.query_configuration.disable_streaming,
         },
         associatedDataSourceId=session.associated_data_source_id,
     )
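On the Python side, dict.get with a default does the backward-compatible work: session payloads produced by a backend that predates disableStreaming simply fall back to False (streaming on). A condensed, self-contained version of that parsing step, exercised with a hypothetical old-style payload:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SessionQueryConfiguration:
    enable_summary_filter: bool
    enable_tool_calling: bool = False
    selected_tools: list[str] = field(default_factory=list)
    disable_streaming: bool = False


def config_from_java(qc: dict[str, Any]) -> SessionQueryConfiguration:
    # .get(key, default) keeps pre-upgrade payloads (no "disableStreaming") valid.
    return SessionQueryConfiguration(
        enable_summary_filter=qc["enableSummaryFilter"],
        enable_tool_calling=qc.get("enableToolCalling", False),
        disable_streaming=qc.get("disableStreaming", False),
        selected_tools=qc["selectedTools"] or [],
    )


old_payload = {"enableSummaryFilter": True, "selectedTools": None}
print(config_from_java(old_payload))  # disable_streaming=False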

llm-service/app/services/query/agents/tool_calling_querier.py

Lines changed: 61 additions & 25 deletions
@@ -74,6 +74,7 @@
     FlexibleContextChatEngine,
 )
 from app.services.query.chat_events import ChatEvent
+from app.services.query.query_configuration import QueryConfiguration

 if os.environ.get("ENABLE_OPIK") == "True":
     opik.configure(
@@ -199,6 +200,7 @@ def stream_chat(
     chat_messages: list[ChatMessage],
     session: Session,
     data_source_summaries: dict[int, str],
+    configuration: QueryConfiguration,
 ) -> StreamingAgentChatResponse:
     mcp_tools: list[BaseTool] = []
     if session.query_configuration and session.query_configuration.selected_tools:
@@ -221,7 +223,7 @@ def stream_chat(
         tools.insert(0, retrieval_tool)

     gen, source_nodes = _run_streamer(
-        chat_engine, chat_messages, enhanced_query, llm, tools
+        chat_engine, chat_messages, enhanced_query, llm, tools, configuration
     )

     return StreamingAgentChatResponse(chat_stream=gen, source_nodes=source_nodes)
@@ -233,9 +235,12 @@ def _run_streamer(
     enhanced_query: str,
     llm: FunctionCallingLLM,
     tools: list[BaseTool],
+    configuration: QueryConfiguration,
     verbose: bool = True,
 ) -> tuple[Generator[ChatResponse, None, None], list[NodeWithScore]]:
-    agent, enhanced_query = build_function_agent(enhanced_query, llm, tools)
+    agent, enhanced_query = build_function_agent(
+        enhanced_query, llm, tools, configuration.use_streaming or False
+    )

     source_nodes: list[NodeWithScore] = []
@@ -251,11 +256,22 @@ def _run_streamer(
         return chat_gen.chat_stream, chat_gen.source_nodes

     # If no chat engine is provided, we can use the LLM directly
-    direct_chat_gen = llm.stream_chat(
-        messages=chat_messages
-        + [ChatMessage(role=MessageRole.USER, content=enhanced_query)]
-    )
-    return direct_chat_gen, source_nodes
+    if configuration.use_streaming:
+        direct_chat_gen = llm.stream_chat(
+            messages=chat_messages
+            + [ChatMessage(role=MessageRole.USER, content=enhanced_query)]
+        )
+        return direct_chat_gen, source_nodes
+
+    # Use non-streaming LLM for direct chat when streaming is disabled
+    def _fake_direct_stream() -> Generator[ChatResponse, None, None]:
+        response = llm.chat(
+            messages=chat_messages
+            + [ChatMessage(role=MessageRole.USER, content=enhanced_query)]
+        )
+        yield response
+
+    return _fake_direct_stream(), source_nodes

     async def agen() -> AsyncGenerator[ChatResponse, None]:
         handler = agent.run(user_msg=enhanced_query, chat_history=chat_messages)
@@ -358,23 +374,33 @@
                     f"{str(event.response) if event.response else 'No content'}"
                 )
                 logger.info("========================")
-                yield ChatResponse(
-                    message=ChatMessage(
-                        role=MessageRole.TOOL,
-                        content=(
-                            event.response.content if event.response.content else ""
+                if configuration.use_streaming:
+                    yield ChatResponse(
+                        message=ChatMessage(
+                            role=(MessageRole.TOOL),
+                            content=event.response.content,
                         ),
-                    ),
-                    delta="",
-                    raw=event.raw,
-                    additional_kwargs={
-                        "chat_event": ChatEvent(
-                            type="agent_response",
-                            name=event.current_agent_name,
-                            data=data,
+                        delta="",
+                        raw=event.raw,
+                        additional_kwargs=(
+                            {
+                                "chat_event": ChatEvent(
+                                    type="agent_response",
+                                    name=event.current_agent_name,
+                                    data=data,
+                                ),
+                            }
                         ),
-                    },
-                )
+                    )
+                else:
+                    yield ChatResponse(
+                        message=ChatMessage(
+                            role=(MessageRole.ASSISTANT),
+                            content=event.response.content,
+                        ),
+                        delta=(event.response.content),
+                        raw=event.raw,
+                    )
             elif isinstance(event, AgentStream):
                 if len(event.tool_calls) > 0:
                     continue
@@ -436,15 +462,22 @@ def gen() -> Generator[ChatResponse, None, None]:


 def build_function_agent(
-    enhanced_query: str, llm: FunctionCallingLLM, tools: list[BaseTool]
+    enhanced_query: str,
+    llm: FunctionCallingLLM,
+    tools: list[BaseTool],
+    streaming_enabled: bool,
 ) -> tuple[FunctionAgent, str]:
     formatted_prompt = DEFAULT_AGENT_PROMPT.format(
         date=datetime.datetime.now().strftime("%A, %B %d, %Y"),
         time=datetime.datetime.now().strftime("%H:%M:%S %p"),
     )
     callable_tools = cast(list[BaseTool | Callable[[], Any]], tools)
     if llm.metadata.model_name in NON_SYSTEM_MESSAGE_MODELS:
-        agent = FunctionAgent(tools=callable_tools, llm=llm)
+        agent = FunctionAgent(
+            tools=callable_tools,
+            llm=llm,
+            streaming=streaming_enabled,
+        )
         enhanced_query = (
             "ROLE DESCRIPTION =========================================\n"
             + formatted_prompt
@@ -460,7 +493,10 @@ def build_function_agent(
     ):
         llm = FakeStreamBedrockConverse.from_bedrock_converse(llm)
         agent = FunctionAgent(
-            tools=callable_tools, llm=llm, system_prompt=formatted_prompt
+            tools=callable_tools,
+            llm=llm,
+            system_prompt=formatted_prompt,
+            streaming=streaming_enabled,
         )

     return agent, enhanced_query
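The _fake_direct_stream helper is the core trick of the non-streaming path: wrap a blocking llm.chat call in a generator that yields exactly once, so every downstream consumer can keep treating the result as a stream. The pattern in isolation, with illustrative names and types:

from typing import Callable, Generator, TypeVar

T = TypeVar("T")


def as_single_item_stream(blocking_call: Callable[[], T]) -> Generator[T, None, None]:
    # Adapt a blocking call to the streaming interface: one yield, then done.
    # Generators are lazy, so the call only runs on first iteration.
    yield blocking_call()


# Code written against a stream works unchanged:
for chunk in as_single_item_stream(lambda: "the complete response"):
    print(chunk)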

0 commit comments
