From 376693be04e64c171b65f13b19b0165b05fc050c Mon Sep 17 00:00:00 2001
From: Lzzzt
Date: Fri, 7 Nov 2025 14:30:57 +0800
Subject: [PATCH] feat: make `LlmGenerationClient::generate` return json

fix: handle json value
---
 rust/cocoindex/src/llm/anthropic.rs             |  6 +++---
 rust/cocoindex/src/llm/bedrock.rs               |  4 ++--
 rust/cocoindex/src/llm/gemini.rs                | 26 +++++++++++++++++-----
 rust/cocoindex/src/llm/mod.rs                   |  5 +++--
 rust/cocoindex/src/llm/ollama.rs                | 13 +++++++++----
 rust/cocoindex/src/llm/openai.rs                | 32 ++++++++++++++++--------
 .../src/ops/functions/extract_by_llm.rs         |  5 ++++-
 7 files changed, 66 insertions(+), 25 deletions(-)

diff --git a/rust/cocoindex/src/llm/anthropic.rs b/rust/cocoindex/src/llm/anthropic.rs
index c84e2c48..6cf55a55 100644
--- a/rust/cocoindex/src/llm/anthropic.rs
+++ b/rust/cocoindex/src/llm/anthropic.rs
@@ -123,7 +123,7 @@ impl LlmGenerationClient for Client {
         }
         let text = if let Some(json) = extracted_json {
-            // Try strict JSON serialization first
-            serde_json::to_string(&json)?
+            // Return the structured tool output directly as JSON
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fallback: try text if no tool output found
             match &mut resp_json["content"][0]["text"] {
@@ -155,7 +155,7 @@
             }
         };
 
-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
diff --git a/rust/cocoindex/src/llm/bedrock.rs b/rust/cocoindex/src/llm/bedrock.rs
index 0f5f9903..e96ca55c 100644
--- a/rust/cocoindex/src/llm/bedrock.rs
+++ b/rust/cocoindex/src/llm/bedrock.rs
@@ -148,7 +148,7 @@ impl LlmGenerationClient for Client {
 
         if let Some(json) = extracted_json {
             // Return the structured output as JSON
-            serde_json::to_string(&json)?
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fall back to text content
             let mut text_parts = Vec::new();
@@ -165,7 +165,7 @@ impl LlmGenerationClient for Client {
             return Err(anyhow::anyhow!("No content found in Bedrock response"));
         };
 
-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
diff --git a/rust/cocoindex/src/llm/gemini.rs b/rust/cocoindex/src/llm/gemini.rs
index 0e6ce1bc..a300b6e8 100644
--- a/rust/cocoindex/src/llm/gemini.rs
+++ b/rust/cocoindex/src/llm/gemini.rs
@@ -147,8 +147,11 @@ impl LlmGenerationClient for AiStudioClient {
             });
         }
 
+        let mut need_json = false;
+
         // If structured output is requested, add schema and responseMimeType
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let mut schema_json = serde_json::to_value(schema)?;
             remove_additional_properties(&mut schema_json);
             payload["generationConfig"] = serde_json::json!({
@@ -161,18 +164,24 @@ impl LlmGenerationClient for AiStudioClient {
         let resp = http::request(|| self.client.post(&url).json(&payload))
             .await
             .context("Gemini API error")?;
-        let resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
         if let Some(error) = resp_json.get("error") {
             bail!("Gemini API error: {:?}", error);
         }
-        let mut resp_json = resp_json;
+
         let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
             Value::String(s) => std::mem::take(s),
             _ => bail!("No text in response"),
         };
-        Ok(LlmGenerateResponse { text })
+
+        if need_json {
+            // Structured output arrives as a JSON string in the text part; parse it.
+            return Ok(LlmGenerateResponse::Json(serde_json::from_str(&text)?));
+        }
+
+        Ok(LlmGenerateResponse::Text(text))
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -333,9 +342,12 @@ impl LlmGenerationClient for VertexAiClient {
                 .set_parts(vec![Part::new().set_text(sys.to_string())])
         });
 
+        let mut need_json = false;
+
         // Compose generation config
         let mut generation_config = None;
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let schema_json = serde_json::to_value(schema)?;
             generation_config = Some(
                 GenerationConfig::new()
@@ -370,7 +382,13 @@ impl LlmGenerationClient for VertexAiClient {
         else {
             bail!("No text in response");
         };
-        Ok(super::LlmGenerateResponse { text })
+
+        if need_json {
+            // Structured output arrives as a JSON string in the text part; parse it.
+            return Ok(super::LlmGenerateResponse::Json(serde_json::from_str(&text)?));
+        }
+
+        Ok(super::LlmGenerateResponse::Text(text))
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
diff --git a/rust/cocoindex/src/llm/mod.rs b/rust/cocoindex/src/llm/mod.rs
index 00df51a2..319eec3c 100644
--- a/rust/cocoindex/src/llm/mod.rs
+++ b/rust/cocoindex/src/llm/mod.rs
@@ -66,8 +66,9 @@ pub struct LlmGenerateRequest<'a> {
 }
 
 #[derive(Debug)]
-pub struct LlmGenerateResponse {
-    pub text: String,
+pub enum LlmGenerateResponse {
+    Text(String),
+    Json(serde_json::Value),
 }
 
 #[async_trait]
diff --git a/rust/cocoindex/src/llm/ollama.rs b/rust/cocoindex/src/llm/ollama.rs
index b02a6ddc..fc36ee93 100644
--- a/rust/cocoindex/src/llm/ollama.rs
+++ b/rust/cocoindex/src/llm/ollama.rs
@@ -108,10 +108,15 @@ impl LlmGenerationClient for Client {
         })
         .await
         .context("Ollama API error")?;
-        let json: OllamaResponse = res.json().await?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+
+        let json: OllamaResponse = res.json().await?;
+        if request.output_format.is_some() {
+            // Structured output was requested, so `response` holds a JSON string.
+            return Ok(super::LlmGenerateResponse::Json(serde_json::from_str(
+                &json.response,
+            )?));
+        }
+        Ok(super::LlmGenerateResponse::Text(json.response))
     }
 
     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
diff --git a/rust/cocoindex/src/llm/openai.rs b/rust/cocoindex/src/llm/openai.rs
index 67a23be3..a89f9d99 100644
--- a/rust/cocoindex/src/llm/openai.rs
+++ b/rust/cocoindex/src/llm/openai.rs
@@ -1,4 +1,4 @@
-use crate::prelude::*;
+use crate::{llm::OutputFormat, prelude::*};
 use base64::prelude::*;
 
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
@@ -145,15 +145,29 @@ impl LlmGenerationClient for Client {
         )
         .await?;
 
-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+        let mut response_iter = response.choices.into_iter();
 
-        Ok(super::LlmGenerateResponse { text })
+        match request.output_format {
+            Some(OutputFormat::JsonSchema { .. }) => {
+                // Parse the structured output from the first choice as JSON
+                let text = response_iter
+                    .next()
+                    .and_then(|choice| choice.message.content)
+                    .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+                let json = serde_json::from_str(&text)?;
+                Ok(super::LlmGenerateResponse::Json(json))
+            }
+            None => {
+                // Extract the response text from the first choice
+                let text = response_iter
+                    .next()
+                    .and_then(|choice| choice.message.content)
+                    .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+                Ok(super::LlmGenerateResponse::Text(text))
+            }
+        }
     }
 
     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
diff --git a/rust/cocoindex/src/ops/functions/extract_by_llm.rs b/rust/cocoindex/src/ops/functions/extract_by_llm.rs
index 4dfe9d4d..6ea88722 100644
--- a/rust/cocoindex/src/ops/functions/extract_by_llm.rs
+++ b/rust/cocoindex/src/ops/functions/extract_by_llm.rs
@@ -113,7 +113,10 @@ impl SimpleFunctionExecutor for Executor {
             }),
         };
         let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = utils::deser::from_json_str(res.text.as_str())?;
+        let json_value = match res {
+            crate::llm::LlmGenerateResponse::Text(text) => utils::deser::from_json_str(&text)?,
+            crate::llm::LlmGenerateResponse::Json(value) => value,
+        };
         let value = self.value_extractor.extract_value(json_value)?;
         Ok(value)
     }
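-- 
Note (not part of the patch): a minimal sketch of how a caller is expected to
consume the new two-variant `LlmGenerateResponse`, mirroring the
`extract_by_llm.rs` change above. `generate_json` is a hypothetical helper,
not something this patch adds:

    use crate::llm::{LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient};

    async fn generate_json(
        client: &dyn LlmGenerationClient,
        req: LlmGenerateRequest<'_>,
    ) -> anyhow::Result<serde_json::Value> {
        match client.generate(req).await? {
            // Providers without native structured output still return text,
            // which may itself be JSON and must be parsed by the caller.
            LlmGenerateResponse::Text(text) => Ok(serde_json::from_str(&text)?),
            // Structured-output providers now hand back a parsed value directly.
            LlmGenerateResponse::Json(value) => Ok(value),
        }
    }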