Skip to content

Commit 948ab21

Browse files
Merge pull request #311 from cloudera/main
Release 1.29.0
2 parents 10c72d3 + 8d58bc5 commit 948ab21

File tree

34 files changed

+3275
-2974
lines changed

34 files changed

+3275
-2974
lines changed

backend/src/main/java/com/cloudera/cai/rag/Types.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,10 +103,12 @@ public record RagDataSource(
103103
Long associatedSessionId) {}
104104

105105
@With
106+
@Builder
106107
public record QueryConfiguration(
107108
boolean enableHyde,
108109
boolean enableSummaryFilter,
109110
boolean enableToolCalling,
111+
Boolean disableStreaming,
110112
List<String> selectedTools) {}
111113

112114
@With

backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
import com.cloudera.cai.rag.configuration.JdbiConfiguration;
4444
import com.cloudera.cai.util.exceptions.NotFound;
4545
import com.fasterxml.jackson.core.JsonProcessingException;
46+
import com.fasterxml.jackson.databind.DeserializationFeature;
4647
import com.fasterxml.jackson.databind.ObjectMapper;
4748
import java.time.Instant;
4849
import java.util.List;
@@ -57,9 +58,16 @@
5758
@Component
5859
public class SessionRepository {
5960
public static final Types.QueryConfiguration DEFAULT_QUERY_CONFIGURATION =
60-
new Types.QueryConfiguration(false, true, false, List.of());
61+
Types.QueryConfiguration.builder()
62+
.enableHyde(false)
63+
.enableSummaryFilter(true)
64+
.enableToolCalling(false)
65+
.disableStreaming(true)
66+
.selectedTools(List.of())
67+
.build();
6168
private final DatabaseOperations databaseOperations;
62-
private final ObjectMapper objectMapper = new ObjectMapper();
69+
private final ObjectMapper objectMapper =
70+
new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
6371

6472
public SessionRepository(DatabaseOperations databaseOperations) {
6573
this.databaseOperations = databaseOperations;
@@ -177,6 +185,9 @@ private Types.QueryConfiguration extractQueryConfiguration(RowView rowView)
177185
if (queryConfiguration.selectedTools() == null) {
178186
queryConfiguration = queryConfiguration.withSelectedTools(List.of());
179187
}
188+
if (queryConfiguration.disableStreaming() == null) {
189+
queryConfiguration = queryConfiguration.withDisableStreaming(false);
190+
}
180191
return queryConfiguration;
181192
}
182193

backend/src/test/java/com/cloudera/cai/rag/TestData.java

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,12 @@ public static Types.Session createTestSessionInstance(
8383
null,
8484
"test-rerank-model",
8585
3,
86-
new Types.QueryConfiguration(false, true, true, List.of()));
86+
Types.QueryConfiguration.builder()
87+
.enableSummaryFilter(true)
88+
.enableToolCalling(true)
89+
.disableStreaming(true)
90+
.selectedTools(List.of())
91+
.build());
8792
}
8893

8994
public static Types.CreateSession createSessionInstance(String sessionName) {
@@ -99,7 +104,13 @@ public static Types.CreateSession createSessionInstance(
99104
null,
100105
"test-rerank-model",
101106
3,
102-
new Types.QueryConfiguration(false, true, true, List.of()),
107+
Types.QueryConfiguration.builder()
108+
.enableHyde(false)
109+
.enableSummaryFilter(true)
110+
.enableToolCalling(true)
111+
.disableStreaming(true)
112+
.selectedTools(List.of())
113+
.build(),
103114
projectId);
104115
}
105116

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,14 @@ void update() {
137137
var updatedRerankModel = "new-rerank-model";
138138
var updatedProjectId = project.id();
139139

140+
var queryConfiguration =
141+
Types.QueryConfiguration.builder()
142+
.enableHyde(true)
143+
.enableSummaryFilter(false)
144+
.enableToolCalling(true)
145+
.disableStreaming(true)
146+
.selectedTools(List.of("foo"))
147+
.build();
140148
var updatedSession =
141149
sessionController.update(
142150
insertedSession
@@ -145,8 +153,7 @@ void update() {
145153
.withRerankModel(updatedRerankModel)
146154
.withName(updatedName)
147155
.withProjectId(updatedProjectId)
148-
.withQueryConfiguration(
149-
new Types.QueryConfiguration(true, false, true, List.of("foo"))),
156+
.withQueryConfiguration(queryConfiguration),
150157
request);
151158

152159
assertThat(updatedSession.id()).isNotNull();
@@ -160,8 +167,7 @@ void update() {
160167
assertThat(updatedSession.timeUpdated()).isAfter(insertedSession.timeUpdated());
161168
assertThat(updatedSession.createdById()).isEqualTo("test-user");
162169
assertThat(updatedSession.lastInteractionTime()).isNull();
163-
assertThat(updatedSession.queryConfiguration())
164-
.isEqualTo(new Types.QueryConfiguration(true, false, true, List.of("foo")));
170+
assertThat(updatedSession.queryConfiguration()).isEqualTo(queryConfiguration);
165171
}
166172

167173
@Test

llm-service/app/ai/vector_stores/opensearch.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,11 @@ def _find_dim(data_source_id: int) -> int:
112112

113113
def size(self) -> Optional[int]:
114114
os_client = self._low_level_client
115-
return int(os_client.count(index=self.table_name)["count"])
115+
try:
116+
return int(os_client.count(index=self.table_name)["count"])
117+
except opensearchpy.exceptions.NotFoundError:
118+
# Return 0 if index doesn't exist yet
119+
return 0
116120

117121
def delete(self) -> None:
118122
os_client = self._low_level_client

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,18 @@ def generate_stream() -> Generator[str, None, None]:
301301

302302
first_message = True
303303
stream = future.result()
304+
305+
# If streaming is disabled, immediately send a loading event to show StreamedEvents
306+
if session.query_configuration.disable_streaming:
307+
loading = ChatEvent(
308+
type="thinking",
309+
name="thinking",
310+
timestamp=time.time(),
311+
data="Preparing full response...",
312+
)
313+
event_json = json.dumps({"event": loading.model_dump()})
314+
yield f"data: {event_json}\n\n"
315+
first_message = False
304316
for item in stream:
305317
response: ChatResponse = item
306318
# Check for cancellation between each response

llm-service/app/services/caii/caii.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -302,9 +302,11 @@ def get_embedding_model(model_name: str) -> BaseEmbedding:
302302

303303

304304
def get_caii_llm_models() -> List[ModelResponse]:
305-
potential_models = get_models_with_task("TEXT_GENERATION")
305+
potential_text_models = get_models_with_task("TEXT_GENERATION")
306+
potential_text_to_text_models = get_models_with_task("TEXT_TO_TEXT_GENERATION")
307+
306308
results: list[Endpoint] = []
307-
for potential in potential_models:
309+
for potential in [*potential_text_models, *potential_text_to_text_models]:
308310
try:
309311
model = get_llm(
310312
endpoint=potential,

llm-service/app/services/chat/streaming_chat.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@ def stream_chat(
8383
use_hyde=session.query_configuration.enable_hyde,
8484
use_summary_filter=session.query_configuration.enable_summary_filter,
8585
use_tool_calling=session.query_configuration.enable_tool_calling,
86+
use_streaming=not session.query_configuration.disable_streaming,
8687
)
8788

8889
response_id = str(uuid.uuid4())
@@ -188,14 +189,20 @@ def _stream_direct_llm_chat(
188189
user_name: Optional[str],
189190
) -> Generator[ChatResponse, None, None]:
190191
record_direct_llm_mlflow_run(response_id, session, user_name)
191-
192-
chat_response = llm_completion.stream_completion(
193-
session.id, query, session.inference_model
194-
)
195-
response: ChatResponse = ChatResponse(message=ChatMessage(content=query))
196-
for response in chat_response:
192+
response: ChatResponse
193+
if session.query_configuration.disable_streaming:
194+
# Use non-streaming completion when streaming is disabled
195+
response = llm_completion.completion(session.id, query, session.inference_model)
197196
response.additional_kwargs["response_id"] = response_id
198197
yield response
198+
else:
199+
chat_response = llm_completion.stream_completion(
200+
session.id, query, session.inference_model
201+
)
202+
response = ChatResponse(message=ChatMessage(content=query))
203+
for response in chat_response:
204+
response.additional_kwargs["response_id"] = response_id
205+
yield response
199206

200207
new_chat_message = RagStudioChatMessage(
201208
id=response_id,

llm-service/app/services/metadata_apis/session_metadata_api.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ class SessionQueryConfiguration:
5252
enable_summary_filter: bool
5353
enable_tool_calling: bool = False
5454
selected_tools: list[str] = field(default_factory=list)
55+
disable_streaming: bool = False
5556

5657

5758
@dataclass
@@ -126,6 +127,9 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
126127
enable_tool_calling=data["queryConfiguration"].get(
127128
"enableToolCalling", False
128129
),
130+
disable_streaming=data["queryConfiguration"].get(
131+
"disableStreaming", False
132+
),
129133
selected_tools=data["queryConfiguration"]["selectedTools"] or [],
130134
),
131135
associated_data_source_id=data.get("associatedDataSourceId", None),
@@ -146,6 +150,7 @@ def update_session(session: Session, user_name: Optional[str]) -> Session:
146150
"enableSummaryFilter": session.query_configuration.enable_summary_filter,
147151
"enableToolCalling": session.query_configuration.enable_tool_calling,
148152
"selectedTools": session.query_configuration.selected_tools,
153+
"disableStreaming": session.query_configuration.disable_streaming,
149154
},
150155
associatedDataSourceId=session.associated_data_source_id,
151156
)

llm-service/app/services/query/agents/non_streamer_bedrock_converse.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

0 commit comments

Comments
 (0)