From 52ae02bdb5a66981ff9909bf410e1253867a6688 Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Thu, 11 Sep 2025 23:29:19 -0700 Subject: [PATCH 1/5] Add reasoning support for Cohere chat generator --- .../generators/cohere/chat/chat_generator.py | 139 +++++++++- .../cohere/tests/test_chat_generator.py | 254 +++++++++++++++++- 2 files changed, 387 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 278e47a1e6..2437698e95 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -1,9 +1,10 @@ import json +import re from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, get_args from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message -from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall +from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall from haystack.dataclasses.streaming_chunk import ( AsyncStreamingCallbackT, FinishReason, @@ -202,11 +203,20 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: ) ) + # Extract reasoning from content if present, even with tool calls + reasoning_content = None + if chat_response.message.content and hasattr(chat_response.message.content[0], "text"): + raw_content = chat_response.message.content[0].text + reasoning_content, _ = _extract_reasoning_from_response(raw_content) + # Create message with tool plan as text and tool calls in the format Haystack expects tool_plan = chat_response.message.tool_plan or "" - message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) + message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content) elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"): - message = ChatMessage.from_assistant(chat_response.message.content[0].text) + raw_content = chat_response.message.content[0].text + # Extract reasoning content if present + reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content) + message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content) else: # Handle the case where neither tool_calls nor content exists logger.warning(f"Received empty response from Cohere API: {chat_response.message}") @@ -350,6 +360,125 @@ def _convert_cohere_chunk_to_streaming_chunk( ) +def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]: + """ + Extract reasoning content from Cohere's response if present. + + Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content + in various formats. This function attempts to identify and extract such content. + + :param response_text: The raw response text from Cohere + :returns: A tuple of (ReasoningContent or None, cleaned_response_text) + """ + if not response_text or not isinstance(response_text, str): + return None, response_text + + # Check for reasoning markers that Cohere might use + + # Pattern 1: Look for thinking/reasoning tags + thinking_patterns = [ + r"(.*?)", + r"(.*?)", + r"## Reasoning\s*\n(.*?)(?=\n## |$)", + r"## Thinking\s*\n(.*?)(?=\n## |$)", + ] + + for pattern in thinking_patterns: + match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE) + if match: + reasoning_text = match.group(1).strip() + cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() + # Apply minimum length threshold for tag-based reasoning + min_reasoning_length = 30 + if len(reasoning_text) > min_reasoning_length: + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content + else: + # Content too short, but still clean the tags + return None, cleaned_content + + # Pattern 2: Look for step-by-step reasoning at start + lines = response_text.split("\n") + reasoning_lines = [] + content_lines = [] + found_separator = False + + for i, line in enumerate(lines): + stripped_line = line.strip() + # Look for reasoning indicators at the beginning of lines (more precise) + if ( + stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve")) + or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning")) + or ( + len(stripped_line) > 0 + and stripped_line.endswith(":") + and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower()) + ) + ): + # Look for a clear separator to determine where reasoning ends + reasoning_end = len(lines) # Default to end of text + for j in range(i + 1, len(lines)): + next_line = lines[j].strip() + if next_line.startswith( + ("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer") + ): + reasoning_end = j + break + + reasoning_lines = lines[:reasoning_end] + content_lines = lines[reasoning_end:] + found_separator = True + break + # Stop looking after first few lines + max_lines_to_check = 10 + if i > max_lines_to_check: + break + + if found_separator and reasoning_lines: + reasoning_text = "\n".join(reasoning_lines).strip() + cleaned_content = "\n".join(content_lines).strip() + min_reasoning_length = 30 + if len(reasoning_text) > min_reasoning_length: # Minimum threshold + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content + + # No reasoning detected + return None, response_text + + +def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage: + """ + Convert streaming chunks to ChatMessage with reasoning extraction support. + + This is a custom version of the core utility function that adds reasoning content + extraction for Cohere responses. + """ + # Use the core utility to get the base ChatMessage + base_message = _convert_streaming_chunks_to_chat_message(chunks=chunks) + + # Extract text content to check for reasoning + if not base_message.text: + return base_message + + # Use the text property for reasoning extraction + combined_text = base_message.text + + # Extract reasoning if present + reasoning_content, cleaned_text = _extract_reasoning_from_response(combined_text) + + if reasoning_content is None: + # No reasoning found, return original message + return base_message + + # Create new message with reasoning support + new_message = ChatMessage.from_assistant( + text=cleaned_text, + reasoning=reasoning_content, + tool_calls=base_message.tool_calls, + meta=base_message.meta, + ) + + return new_message + + def _parse_streaming_response( response: Iterator[StreamedChatResponseV2], model: str, @@ -381,7 +510,7 @@ def _parse_streaming_response( chunks.append(streaming_chunk) streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message(chunks=chunks) + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) async def _parse_async_streaming_response( @@ -409,7 +538,7 @@ async def _parse_async_streaming_response( chunks.append(streaming_chunk) await streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message(chunks=chunks) + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) @component diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index 4378ee016c..b90e41158f 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -7,13 +7,14 @@ from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk from haystack.components.tools import ToolInvoker -from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ToolCall +from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ReasoningContent, ToolCall from haystack.dataclasses.streaming_chunk import StreamingChunk from haystack.tools import Tool from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator from haystack_integrations.components.generators.cohere.chat.chat_generator import ( + _extract_reasoning_from_response, _format_message, ) @@ -656,3 +657,254 @@ def test_live_run_multimodal(self): assert len(results["replies"]) == 1 assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 + + +class TestReasoningExtraction: + """Test the reasoning extraction functionality.""" + + def test_extract_reasoning_with_thinking_tags(self): + """Test extraction of reasoning from tags.""" + response_text = """ +I need to calculate the area of a circle. +The formula is π * r². +Given radius is 5, so area = π * 25 = 78.54 + + +The area of a circle with radius 5 is approximately 78.54 square units.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "calculate the area of a circle" in reasoning.reasoning_text + assert "formula is π * r²" in reasoning.reasoning_text + assert "area = π * 25 = 78.54" in reasoning.reasoning_text + assert cleaned.strip() == "The area of a circle with radius 5 is approximately 78.54 square units." + + def test_extract_reasoning_with_reasoning_tags(self): + """Test extraction of reasoning from tags.""" + response_text = """ +Let me think about this step by step: +1. First, I need to understand the problem +2. Then identify the key variables +3. Apply the appropriate formula + + +Based on my analysis, here's the solution.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "step by step" in reasoning.reasoning_text + assert "understand the problem" in reasoning.reasoning_text + assert "key variables" in reasoning.reasoning_text + assert cleaned.strip() == "Based on my analysis, here's the solution." + + def test_extract_reasoning_with_step_by_step_headers(self): + """Test extraction of reasoning from step-by-step format.""" + response_text = """## My reasoning: +Step 1: Analyze the input data +Step 2: Identify patterns +Step 3: Apply the algorithm + +## Solution: +The final answer is 42.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "Step 1: Analyze the input data" in reasoning.reasoning_text + assert "Step 2: Identify patterns" in reasoning.reasoning_text + assert "Step 3: Apply the algorithm" in reasoning.reasoning_text + assert cleaned.strip() == "## Solution:\nThe final answer is 42." + + def test_extract_reasoning_no_reasoning_present(self): + """Test that no reasoning is extracted when none is present.""" + response_text = "This is a simple response without any reasoning content." + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is None + assert cleaned == response_text + + def test_extract_reasoning_short_reasoning_ignored(self): + """Test that very short reasoning content is ignored.""" + response_text = """ +OK + + +The answer is yes.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is None # Too short, should be ignored + assert cleaned.strip() == "The answer is yes." + + def test_extract_reasoning_with_let_me_think(self): + """Test extraction of reasoning starting with 'Let me think'.""" + response_text = """Let me think through this carefully: + +First, I need to consider the constraints of the problem. The user is asking about quantum +mechanics, which requires understanding wave-particle duality. + +Second, I should explain the fundamental principles clearly. + +Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale.""" + + reasoning, cleaned = _extract_reasoning_from_response(response_text) + + assert reasoning is not None + assert isinstance(reasoning, ReasoningContent) + assert "think through this carefully" in reasoning.reasoning_text + assert "constraints of the problem" in reasoning.reasoning_text + assert "wave-particle duality" in reasoning.reasoning_text + assert cleaned.strip() == ( + "Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale." + ) + + +class TestCohereChatGeneratorReasoning: + """Integration tests for reasoning functionality in CohereChatGenerator.""" + + @pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_reasoning_with_command_a_reasoning_model(self): + """Test reasoning extraction with Command A Reasoning model.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", + generation_kwargs={"thinking": True}, # Enable reasoning + ) + + messages = [ + ChatMessage.from_user("Solve this math problem step by step: What is the area of a circle with radius 7?") + ] + + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check if reasoning was extracted + if reply.reasoning: + assert isinstance(reply.reasoning, ReasoningContent) + assert len(reply.reasoning.reasoning_text) > 50 # Should have substantial reasoning + + # The reasoning should contain mathematical thinking + reasoning_lower = reply.reasoning.reasoning_text.lower() + assert any(word in reasoning_lower for word in ["area", "circle", "radius", "formula", "π", "pi"]) + + # Check the main response content + assert len(reply.text) > 0 + response_lower = reply.text.lower() + assert any(word in response_lower for word in ["area", "153.94", "154", "square"]) + + def test_reasoning_with_mock_response(self): + """Test reasoning extraction with mocked Cohere response.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") + ) + + # Mock the Cohere client response + mock_response = MagicMock() + mock_response.message.content = [ + MagicMock( + text=""" +I need to solve for the area of a circle. +The formula is A = πr² +With radius 7: A = π * 7² = π * 49 ≈ 153.94 + + +The area of a circle with radius 7 is approximately 153.94 square units.""" + ) + ] + mock_response.message.tool_calls = None + mock_response.message.citations = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What is the area of a circle with radius 7?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check reasoning extraction + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "formula is A = πr²" in reply.reasoning.reasoning_text + assert "π * 49 ≈ 153.94" in reply.reasoning.reasoning_text + + # Check cleaned content + assert reply.text.strip() == "The area of a circle with radius 7 is approximately 153.94 square units." + + def test_reasoning_with_tool_calls_compatibility(self): + """Test that reasoning works with tool calls.""" + weather_tool = Tool( + name="weather", + description="Get weather for a city", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + function=weather, + ) + + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", tools=[weather_tool], api_key=Secret.from_token("fake-api-key") + ) + + # Mock response with both reasoning and tool calls + mock_response = MagicMock() + mock_response.message.content = [ + MagicMock( + text=""" +The user is asking about weather in Paris. I should use the weather tool to get accurate information. + + +I'll check the weather in Paris for you.""" + ) + ] + + # Mock tool call + mock_tool_call = MagicMock() + mock_tool_call.function.name = "weather" + mock_tool_call.function.arguments = '{"city": "Paris"}' + mock_tool_call.id = "call_123" + mock_response.message.tool_calls = [mock_tool_call] + mock_response.message.tool_plan = "I'll check the weather in Paris for you." + mock_response.message.citations = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + + # Check reasoning extraction + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "weather tool" in reply.reasoning.reasoning_text + + # Check tool calls are preserved + assert reply.tool_calls is not None + assert len(reply.tool_calls) == 1 + assert reply.tool_calls[0].tool_name == "weather" + + # Check cleaned content + assert "I'll check the weather in Paris" in reply.text From 974c60a13884e93490deae2afc507b496991a654 Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Fri, 14 Nov 2025 14:00:25 -0800 Subject: [PATCH 2/5] resolve conflict --- .../components/generators/cohere/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 6c0f6ea293..7a90822a32 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -446,7 +446,7 @@ def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[Reaso return None, response_text -def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage: +def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: """ Convert streaming chunks to ChatMessage with reasoning extraction support. From 57ebd79c2b91db85674e33edc4658f5b8300421b Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Fri, 14 Nov 2025 14:25:45 -0800 Subject: [PATCH 3/5] address PR feedback --- .../generators/cohere/chat/chat_generator.py | 91 +++++------ .../cohere/tests/test_chat_generator.py | 151 +++++------------- 2 files changed, 84 insertions(+), 158 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 7a90822a32..566b372f1a 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -185,6 +185,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: Extracts and organizes various response components including: - Text content - Tool calls + - Reasoning content (via native Cohere API with text-based fallback) - Usage statistics - Citations - Metadata @@ -193,6 +194,29 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: :param model: The name of the model that generated the response. :return: A Haystack ChatMessage containing the formatted response. """ + # Try to extract reasoning content using Cohere's native API (preferred method) + reasoning_content = None + text_content = "" + + if chat_response.message.content: + for content_item in chat_response.message.content: + # Access thinking content via native Cohere API + if hasattr(content_item, "type") and content_item.type == "thinking": + if hasattr(content_item, "thinking") and content_item.thinking: + reasoning_content = ReasoningContent(reasoning_text=content_item.thinking) + # Access text content + elif hasattr(content_item, "type") and content_item.type == "text": + if hasattr(content_item, "text") and content_item.text: + text_content = content_item.text + + # Fallback: If reasoning wasn't found via native API but text contains reasoning markers, + # extract it from text (for backward compatibility) + if reasoning_content is None and text_content: + fallback_reasoning, cleaned_text = _extract_reasoning_from_text(text_content) + if fallback_reasoning is not None: + reasoning_content = fallback_reasoning + text_content = cleaned_text + if chat_response.message.tool_calls: tool_calls = [] for tc in chat_response.message.tool_calls: @@ -205,20 +229,11 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: ) ) - # Extract reasoning from content if present, even with tool calls - reasoning_content = None - if chat_response.message.content and hasattr(chat_response.message.content[0], "text"): - raw_content = chat_response.message.content[0].text - reasoning_content, _ = _extract_reasoning_from_response(raw_content) - # Create message with tool plan as text and tool calls in the format Haystack expects tool_plan = chat_response.message.tool_plan or "" message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content) - elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"): - raw_content = chat_response.message.content[0].text - # Extract reasoning content if present - reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content) - message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content) + elif text_content: + message = ChatMessage.from_assistant(text_content, reasoning=reasoning_content) else: # Handle the case where neither tool_calls nor content exists logger.warning(f"Received empty response from Cohere API: {chat_response.message}") @@ -362,12 +377,12 @@ def _convert_cohere_chunk_to_streaming_chunk( ) -def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]: +def _extract_reasoning_from_text(response_text: str) -> tuple[Optional[ReasoningContent], str]: """ - Extract reasoning content from Cohere's response if present. + Extract reasoning content from text as a fallback method. - Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content - in various formats. This function attempts to identify and extract such content. + This is used when reasoning is not available via native Cohere API + (e.g., in streaming mode or for backward compatibility). :param response_text: The raw response text from Cohere :returns: A tuple of (ReasoningContent or None, cleaned_response_text) @@ -375,8 +390,6 @@ def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[Reaso if not response_text or not isinstance(response_text, str): return None, response_text - # Check for reasoning markers that Cohere might use - # Pattern 1: Look for thinking/reasoning tags thinking_patterns = [ r"(.*?)", @@ -390,23 +403,17 @@ def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[Reaso if match: reasoning_text = match.group(1).strip() cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() - # Apply minimum length threshold for tag-based reasoning min_reasoning_length = 30 if len(reasoning_text) > min_reasoning_length: return ReasoningContent(reasoning_text=reasoning_text), cleaned_content else: - # Content too short, but still clean the tags return None, cleaned_content # Pattern 2: Look for step-by-step reasoning at start lines = response_text.split("\n") - reasoning_lines = [] - content_lines = [] - found_separator = False - + max_lines_to_check = 10 for i, line in enumerate(lines): stripped_line = line.strip() - # Look for reasoning indicators at the beginning of lines (more precise) if ( stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve")) or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning")) @@ -416,8 +423,7 @@ def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[Reaso and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower()) ) ): - # Look for a clear separator to determine where reasoning ends - reasoning_end = len(lines) # Default to end of text + reasoning_end = len(lines) for j in range(i + 1, len(lines)): next_line = lines[j].strip() if next_line.startswith( @@ -428,21 +434,16 @@ def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[Reaso reasoning_lines = lines[:reasoning_end] content_lines = lines[reasoning_end:] - found_separator = True - break - # Stop looking after first few lines - max_lines_to_check = 10 - if i > max_lines_to_check: + reasoning_text = "\n".join(reasoning_lines).strip() + cleaned_content = "\n".join(content_lines).strip() + min_reasoning_length = 30 + if len(reasoning_text) > min_reasoning_length: + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content break - if found_separator and reasoning_lines: - reasoning_text = "\n".join(reasoning_lines).strip() - cleaned_content = "\n".join(content_lines).strip() - min_reasoning_length = 30 - if len(reasoning_text) > min_reasoning_length: # Minimum threshold - return ReasoningContent(reasoning_text=reasoning_text), cleaned_content + if i > max_lines_to_check: # Stop looking after first few lines + break - # No reasoning detected return None, response_text @@ -450,27 +451,19 @@ def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: list[Stream """ Convert streaming chunks to ChatMessage with reasoning extraction support. - This is a custom version of the core utility function that adds reasoning content - extraction for Cohere responses. + For streaming, reasoning might not come via native API, so we use text-based extraction. """ - # Use the core utility to get the base ChatMessage base_message = _convert_streaming_chunks_to_chat_message(chunks=chunks) - # Extract text content to check for reasoning if not base_message.text: return base_message - # Use the text property for reasoning extraction - combined_text = base_message.text - - # Extract reasoning if present - reasoning_content, cleaned_text = _extract_reasoning_from_response(combined_text) + # Try to extract reasoning from text (fallback for streaming) + reasoning_content, cleaned_text = _extract_reasoning_from_text(base_message.text) if reasoning_content is None: - # No reasoning found, return original message return base_message - # Create new message with reasoning support new_message = ChatMessage.from_assistant( text=cleaned_text, reasoning=reasoning_content, diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index de151b1163..21076e87a4 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -14,7 +14,7 @@ from haystack_integrations.components.generators.cohere import CohereChatGenerator from haystack_integrations.components.generators.cohere.chat.chat_generator import ( - _extract_reasoning_from_response, + _extract_reasoning_from_text, _format_message, ) @@ -445,11 +445,14 @@ def test_run_image(self): generator = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) - # Mock the client's chat method + # Mock the client's chat method with proper content structure mock_response = MagicMock() - mock_response.message.content = [MagicMock()] - mock_response.message.content[0].text = "This is a test image response" + text_content = MagicMock() + text_content.type = "text" + text_content.text = "This is a test image response" + mock_response.message.content = [text_content] mock_response.message.tool_calls = None + mock_response.message.citations = None mock_response.finish_reason = "COMPLETE" mock_response.usage = None @@ -729,11 +732,11 @@ def test_live_run_multimodal(self): assert len(results["replies"][0].text) > 0 -class TestReasoningExtraction: - """Test the reasoning extraction functionality.""" +class TestReasoningTextExtraction: + """Test the fallback text-based reasoning extraction functionality.""" def test_extract_reasoning_with_thinking_tags(self): - """Test extraction of reasoning from tags.""" + """Test extraction of reasoning from tags (fallback method).""" response_text = """ I need to calculate the area of a circle. The formula is π * r². @@ -742,98 +745,22 @@ def test_extract_reasoning_with_thinking_tags(self): The area of a circle with radius 5 is approximately 78.54 square units.""" - reasoning, cleaned = _extract_reasoning_from_response(response_text) + reasoning, cleaned = _extract_reasoning_from_text(response_text) assert reasoning is not None assert isinstance(reasoning, ReasoningContent) assert "calculate the area of a circle" in reasoning.reasoning_text - assert "formula is π * r²" in reasoning.reasoning_text - assert "area = π * 25 = 78.54" in reasoning.reasoning_text assert cleaned.strip() == "The area of a circle with radius 5 is approximately 78.54 square units." - def test_extract_reasoning_with_reasoning_tags(self): - """Test extraction of reasoning from tags.""" - response_text = """ -Let me think about this step by step: -1. First, I need to understand the problem -2. Then identify the key variables -3. Apply the appropriate formula - - -Based on my analysis, here's the solution.""" - - reasoning, cleaned = _extract_reasoning_from_response(response_text) - - assert reasoning is not None - assert isinstance(reasoning, ReasoningContent) - assert "step by step" in reasoning.reasoning_text - assert "understand the problem" in reasoning.reasoning_text - assert "key variables" in reasoning.reasoning_text - assert cleaned.strip() == "Based on my analysis, here's the solution." - - def test_extract_reasoning_with_step_by_step_headers(self): - """Test extraction of reasoning from step-by-step format.""" - response_text = """## My reasoning: -Step 1: Analyze the input data -Step 2: Identify patterns -Step 3: Apply the algorithm - -## Solution: -The final answer is 42.""" - - reasoning, cleaned = _extract_reasoning_from_response(response_text) - - assert reasoning is not None - assert isinstance(reasoning, ReasoningContent) - assert "Step 1: Analyze the input data" in reasoning.reasoning_text - assert "Step 2: Identify patterns" in reasoning.reasoning_text - assert "Step 3: Apply the algorithm" in reasoning.reasoning_text - assert cleaned.strip() == "## Solution:\nThe final answer is 42." - def test_extract_reasoning_no_reasoning_present(self): """Test that no reasoning is extracted when none is present.""" response_text = "This is a simple response without any reasoning content." - reasoning, cleaned = _extract_reasoning_from_response(response_text) + reasoning, cleaned = _extract_reasoning_from_text(response_text) assert reasoning is None assert cleaned == response_text - def test_extract_reasoning_short_reasoning_ignored(self): - """Test that very short reasoning content is ignored.""" - response_text = """ -OK - - -The answer is yes.""" - - reasoning, cleaned = _extract_reasoning_from_response(response_text) - - assert reasoning is None # Too short, should be ignored - assert cleaned.strip() == "The answer is yes." - - def test_extract_reasoning_with_let_me_think(self): - """Test extraction of reasoning starting with 'Let me think'.""" - response_text = """Let me think through this carefully: - -First, I need to consider the constraints of the problem. The user is asking about quantum -mechanics, which requires understanding wave-particle duality. - -Second, I should explain the fundamental principles clearly. - -Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale.""" - - reasoning, cleaned = _extract_reasoning_from_response(response_text) - - assert reasoning is not None - assert isinstance(reasoning, ReasoningContent) - assert "think through this carefully" in reasoning.reasoning_text - assert "constraints of the problem" in reasoning.reasoning_text - assert "wave-particle duality" in reasoning.reasoning_text - assert cleaned.strip() == ( - "Based on this analysis, quantum mechanics describes the behavior of matter and energy at the atomic scale." - ) - class TestCohereChatGeneratorReasoning: """Integration tests for reasoning functionality in CohereChatGenerator.""" @@ -875,26 +802,30 @@ def test_reasoning_with_command_a_reasoning_model(self): assert any(word in response_lower for word in ["area", "153.94", "154", "square"]) def test_reasoning_with_mock_response(self): - """Test reasoning extraction with mocked Cohere response.""" + """Test reasoning extraction with mocked Cohere response using native API.""" generator = CohereChatGenerator( model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") ) - # Mock the Cohere client response + # Mock the Cohere client response using native API structure mock_response = MagicMock() - mock_response.message.content = [ - MagicMock( - text=""" -I need to solve for the area of a circle. + + # Create mock content items with thinking and text types + thinking_content = MagicMock() + thinking_content.type = "thinking" + thinking_content.thinking = """I need to solve for the area of a circle. The formula is A = πr² -With radius 7: A = π * 7² = π * 49 ≈ 153.94 - +With radius 7: A = π * 7² = π * 49 ≈ 153.94""" -The area of a circle with radius 7 is approximately 153.94 square units.""" - ) - ] + text_content = MagicMock() + text_content.type = "text" + text_content.text = "The area of a circle with radius 7 is approximately 153.94 square units." + + mock_response.message.content = [thinking_content, text_content] mock_response.message.tool_calls = None mock_response.message.citations = None + mock_response.finish_reason = "COMPLETE" + mock_response.usage = None generator.client.chat = MagicMock(return_value=mock_response) @@ -908,13 +839,13 @@ def test_reasoning_with_mock_response(self): assert isinstance(reply, ChatMessage) assert reply.role == ChatRole.ASSISTANT - # Check reasoning extraction + # Check reasoning extraction via native API assert reply.reasoning is not None assert isinstance(reply.reasoning, ReasoningContent) assert "formula is A = πr²" in reply.reasoning.reasoning_text assert "π * 49 ≈ 153.94" in reply.reasoning.reasoning_text - # Check cleaned content + # Check text content assert reply.text.strip() == "The area of a circle with radius 7 is approximately 153.94 square units." def test_reasoning_with_tool_calls_compatibility(self): @@ -934,17 +865,17 @@ def test_reasoning_with_tool_calls_compatibility(self): model="command-a-reasoning-111b-2024-10-03", tools=[weather_tool], api_key=Secret.from_token("fake-api-key") ) - # Mock response with both reasoning and tool calls + # Mock response with both reasoning and tool calls using native API mock_response = MagicMock() - mock_response.message.content = [ - MagicMock( - text=""" -The user is asking about weather in Paris. I should use the weather tool to get accurate information. - -I'll check the weather in Paris for you.""" - ) - ] + # Create mock content items with thinking type + thinking_content = MagicMock() + thinking_content.type = "thinking" + thinking_content.thinking = ( + "The user is asking about weather in Paris. I should use the weather tool to get accurate information." + ) + + mock_response.message.content = [thinking_content] # Mock tool call mock_tool_call = MagicMock() @@ -954,6 +885,8 @@ def test_reasoning_with_tool_calls_compatibility(self): mock_response.message.tool_calls = [mock_tool_call] mock_response.message.tool_plan = "I'll check the weather in Paris for you." mock_response.message.citations = None + mock_response.finish_reason = "TOOL_CALLS" + mock_response.usage = None generator.client.chat = MagicMock(return_value=mock_response) @@ -966,7 +899,7 @@ def test_reasoning_with_tool_calls_compatibility(self): reply = result["replies"][0] assert isinstance(reply, ChatMessage) - # Check reasoning extraction + # Check reasoning extraction via native API assert reply.reasoning is not None assert isinstance(reply.reasoning, ReasoningContent) assert "weather tool" in reply.reasoning.reasoning_text @@ -976,7 +909,7 @@ def test_reasoning_with_tool_calls_compatibility(self): assert len(reply.tool_calls) == 1 assert reply.tool_calls[0].tool_name == "weather" - # Check cleaned content + # Check tool plan is used as text assert "I'll check the weather in Paris" in reply.text @pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set") From beb56f0e63bfd26452b75b33f809fc444f783313 Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Mon, 17 Nov 2025 20:30:20 -0800 Subject: [PATCH 4/5] Removed fallback logic --- .../embedders/cohere/document_embedder.py | 6 +- .../cohere/document_image_embedder.py | 4 +- .../generators/cohere/chat/chat_generator.py | 156 +++++------------- .../cohere/tests/test_chat_generator.py | 35 +--- 4 files changed, 46 insertions(+), 155 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 020aaa41fd..b7b77b60b2 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -22,14 +22,16 @@ class CohereDocumentEmbedder: Usage example: ```python from haystack import Document - from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder + from haystack_integrations.components.embedders.cohere import ( + CohereDocumentEmbedder, + ) doc = Document(content="I love pizza!") document_embedder = CohereDocumentEmbedder() result = document_embedder.run([doc]) - print(result['documents'][0].embedding) + print(result["documents"][0].embedding) # [-0.453125, 1.2236328, 2.0058594, ...] ``` diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py index 1484365819..7b5316eb99 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py @@ -37,7 +37,9 @@ class CohereDocumentImageEmbedder: ### Usage example ```python from haystack import Document - from haystack_integrations.components.embedders.cohere import CohereDocumentImageEmbedder + from haystack_integrations.components.embedders.cohere import ( + CohereDocumentImageEmbedder, + ) embedder = CohereDocumentImageEmbedder(model="embed-v4.0") diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 566b372f1a..47aaac8718 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -1,5 +1,4 @@ import json -import re from collections.abc import AsyncIterator, Iterator from typing import Any, Literal, Optional, Union, get_args @@ -185,7 +184,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: Extracts and organizes various response components including: - Text content - Tool calls - - Reasoning content (via native Cohere API with text-based fallback) + - Reasoning content (via native Cohere API) - Usage statistics - Citations - Metadata @@ -194,7 +193,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: :param model: The name of the model that generated the response. :return: A Haystack ChatMessage containing the formatted response. """ - # Try to extract reasoning content using Cohere's native API (preferred method) + # Extract reasoning content using Cohere's native API reasoning_content = None text_content = "" @@ -209,14 +208,6 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: if hasattr(content_item, "text") and content_item.text: text_content = content_item.text - # Fallback: If reasoning wasn't found via native API but text contains reasoning markers, - # extract it from text (for backward compatibility) - if reasoning_content is None and text_content: - fallback_reasoning, cleaned_text = _extract_reasoning_from_text(text_content) - if fallback_reasoning is not None: - reasoning_content = fallback_reasoning - text_content = cleaned_text - if chat_response.message.tool_calls: tool_calls = [] for tc in chat_response.message.tool_calls: @@ -377,103 +368,6 @@ def _convert_cohere_chunk_to_streaming_chunk( ) -def _extract_reasoning_from_text(response_text: str) -> tuple[Optional[ReasoningContent], str]: - """ - Extract reasoning content from text as a fallback method. - - This is used when reasoning is not available via native Cohere API - (e.g., in streaming mode or for backward compatibility). - - :param response_text: The raw response text from Cohere - :returns: A tuple of (ReasoningContent or None, cleaned_response_text) - """ - if not response_text or not isinstance(response_text, str): - return None, response_text - - # Pattern 1: Look for thinking/reasoning tags - thinking_patterns = [ - r"(.*?)", - r"(.*?)", - r"## Reasoning\s*\n(.*?)(?=\n## |$)", - r"## Thinking\s*\n(.*?)(?=\n## |$)", - ] - - for pattern in thinking_patterns: - match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE) - if match: - reasoning_text = match.group(1).strip() - cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() - min_reasoning_length = 30 - if len(reasoning_text) > min_reasoning_length: - return ReasoningContent(reasoning_text=reasoning_text), cleaned_content - else: - return None, cleaned_content - - # Pattern 2: Look for step-by-step reasoning at start - lines = response_text.split("\n") - max_lines_to_check = 10 - for i, line in enumerate(lines): - stripped_line = line.strip() - if ( - stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve")) - or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning")) - or ( - len(stripped_line) > 0 - and stripped_line.endswith(":") - and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower()) - ) - ): - reasoning_end = len(lines) - for j in range(i + 1, len(lines)): - next_line = lines[j].strip() - if next_line.startswith( - ("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer") - ): - reasoning_end = j - break - - reasoning_lines = lines[:reasoning_end] - content_lines = lines[reasoning_end:] - reasoning_text = "\n".join(reasoning_lines).strip() - cleaned_content = "\n".join(content_lines).strip() - min_reasoning_length = 30 - if len(reasoning_text) > min_reasoning_length: - return ReasoningContent(reasoning_text=reasoning_text), cleaned_content - break - - if i > max_lines_to_check: # Stop looking after first few lines - break - - return None, response_text - - -def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: - """ - Convert streaming chunks to ChatMessage with reasoning extraction support. - - For streaming, reasoning might not come via native API, so we use text-based extraction. - """ - base_message = _convert_streaming_chunks_to_chat_message(chunks=chunks) - - if not base_message.text: - return base_message - - # Try to extract reasoning from text (fallback for streaming) - reasoning_content, cleaned_text = _extract_reasoning_from_text(base_message.text) - - if reasoning_content is None: - return base_message - - new_message = ChatMessage.from_assistant( - text=cleaned_text, - reasoning=reasoning_content, - tool_calls=base_message.tool_calls, - meta=base_message.meta, - ) - - return new_message - - def _parse_streaming_response( response: Iterator[StreamedChatResponseV2], model: str, @@ -505,7 +399,7 @@ def _parse_streaming_response( chunks.append(streaming_chunk) streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) + return _convert_streaming_chunks_to_chat_message(chunks=chunks) async def _parse_async_streaming_response( @@ -533,7 +427,7 @@ async def _parse_async_streaming_response( chunks.append(streaming_chunk) await streaming_callback(streaming_chunk) - return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) + return _convert_streaming_chunks_to_chat_message(chunks=chunks) @component @@ -559,9 +453,13 @@ class CohereChatGenerator: ```python from haystack.dataclasses import ChatMessage from haystack.utils import Secret - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) - client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY")) + client = CohereChatGenerator( + model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY") + ) messages = [ChatMessage.from_user("What's Natural Language Processing?")] client.run(messages) @@ -573,16 +471,25 @@ class CohereChatGenerator: ```python from haystack.dataclasses import ChatMessage, ImageContent from haystack.utils import Secret - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) # Create an image from file path or base64 image_content = ImageContent.from_file_path("path/to/your/image.jpg") # Create a multimodal message with both text and image - messages = [ChatMessage.from_user(content_parts=["What's in this image?", image_content])] + messages = [ + ChatMessage.from_user( + content_parts=["What's in this image?", image_content] + ) + ] # Use a multimodal model like Command A Vision - client = CohereChatGenerator(model="command-a-vision-07-2025", api_key=Secret.from_env_var("COHERE_API_KEY")) + client = CohereChatGenerator( + model="command-a-vision-07-2025", + api_key=Secret.from_env_var("COHERE_API_KEY"), + ) response = client.run(messages) print(response) ``` @@ -597,12 +504,16 @@ class CohereChatGenerator: from haystack.dataclasses import ChatMessage from haystack.components.tools import ToolInvoker from haystack.tools import Tool - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) + # Create a weather tool def weather(city: str) -> str: return f"The weather in {city} is sunny and 32°C" + weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", @@ -621,13 +532,22 @@ def weather(city: str) -> str: # Create and set up the pipeline pipeline = Pipeline() - pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool])) + pipeline.add_component( + "generator", + CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]), + ) pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) pipeline.connect("generator", "tool_invoker") # Run the pipeline with a weather query results = pipeline.run( - data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + data={ + "generator": { + "messages": [ + ChatMessage.from_user("What's the weather like in Paris?") + ] + } + } ) # The tool result will be available in the pipeline output diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index 21076e87a4..74153fa870 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -13,10 +13,7 @@ from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator -from haystack_integrations.components.generators.cohere.chat.chat_generator import ( - _extract_reasoning_from_text, - _format_message, -) +from haystack_integrations.components.generators.cohere.chat.chat_generator import _format_message def weather(city: str) -> str: @@ -732,36 +729,6 @@ def test_live_run_multimodal(self): assert len(results["replies"][0].text) > 0 -class TestReasoningTextExtraction: - """Test the fallback text-based reasoning extraction functionality.""" - - def test_extract_reasoning_with_thinking_tags(self): - """Test extraction of reasoning from tags (fallback method).""" - response_text = """ -I need to calculate the area of a circle. -The formula is π * r². -Given radius is 5, so area = π * 25 = 78.54 - - -The area of a circle with radius 5 is approximately 78.54 square units.""" - - reasoning, cleaned = _extract_reasoning_from_text(response_text) - - assert reasoning is not None - assert isinstance(reasoning, ReasoningContent) - assert "calculate the area of a circle" in reasoning.reasoning_text - assert cleaned.strip() == "The area of a circle with radius 5 is approximately 78.54 square units." - - def test_extract_reasoning_no_reasoning_present(self): - """Test that no reasoning is extracted when none is present.""" - response_text = "This is a simple response without any reasoning content." - - reasoning, cleaned = _extract_reasoning_from_text(response_text) - - assert reasoning is None - assert cleaned == response_text - - class TestCohereChatGeneratorReasoning: """Integration tests for reasoning functionality in CohereChatGenerator.""" From 309fbf2961ed93e5e60a10098001b43b8e05d9b9 Mon Sep 17 00:00:00 2001 From: Chinmay Bansal Date: Tue, 18 Nov 2025 23:21:30 -0800 Subject: [PATCH 5/5] fix test case and add streamign support for reasoning content --- .../generators/cohere/chat/chat_generator.py | 13 +++- .../cohere/tests/test_chat_generator.py | 66 ++++++++++++++++--- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 47aaac8718..593dbeb292 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -282,11 +282,19 @@ def _convert_cohere_chunk_to_streaming_chunk( start = False finish_reason = None tool_calls = None + reasoning = None meta: dict[str, Any] = {"model": model} if chunk.type == "content-delta" and chunk.delta and chunk.delta.message: - if chunk.delta.message and chunk.delta.message.content and chunk.delta.message.content.text is not None: - content = chunk.delta.message.content.text + if chunk.delta.message and chunk.delta.message.content: + # Handle thinking/reasoning content (prioritize over text if both exist) + # Note: StreamingChunk can only have ONE of content/tool_calls/reasoning set + thinking_text = getattr(chunk.delta.message.content, "thinking", None) + if thinking_text is not None and isinstance(thinking_text, str): + reasoning = ReasoningContent(reasoning_text=thinking_text) + # Handle text content (only if no reasoning) + elif chunk.delta.message.content.text is not None: + content = chunk.delta.message.content.text elif chunk.type == "tool-plan-delta" and chunk.delta and chunk.delta.message: if chunk.delta.message and chunk.delta.message.tool_plan is not None: @@ -362,6 +370,7 @@ def _convert_cohere_chunk_to_streaming_chunk( component_info=component_info, index=index, tool_calls=tool_calls, + reasoning=reasoning, start=start, finish_reason=finish_reason, meta=meta, diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index 74153fa870..f4e57192d4 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -728,12 +728,6 @@ def test_live_run_multimodal(self): assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 - -class TestCohereChatGeneratorReasoning: - """Integration tests for reasoning functionality in CohereChatGenerator.""" - - @pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set") - @pytest.mark.integration def test_reasoning_with_command_a_reasoning_model(self): """Test reasoning extraction with Command A Reasoning model.""" generator = CohereChatGenerator( @@ -879,8 +873,64 @@ def test_reasoning_with_tool_calls_compatibility(self): # Check tool plan is used as text assert "I'll check the weather in Paris" in reply.text - @pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set") - @pytest.mark.integration + def test_streaming_reasoning_with_mock_chunks(self): + """Test that reasoning content is captured during streaming.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") + ) + + # Mock streaming chunks with thinking content + thinking_chunk = MagicMock() + thinking_chunk.type = "content-delta" + thinking_chunk.delta.message.content.thinking = "Let me calculate the area step by step. " + thinking_chunk.delta.message.content.text = None + + thinking_chunk2 = MagicMock() + thinking_chunk2.type = "content-delta" + thinking_chunk2.delta.message.content.thinking = "Using formula A = πr²." + thinking_chunk2.delta.message.content.text = None + + text_chunk = MagicMock() + text_chunk.type = "content-delta" + text_chunk.delta.message.content.text = "The area is " + text_chunk.delta.message.content.thinking = None + + text_chunk2 = MagicMock() + text_chunk2.type = "content-delta" + text_chunk2.delta.message.content.text = "78.54 square units." + text_chunk2.delta.message.content.thinking = None + + end_chunk = MagicMock() + end_chunk.type = "message-end" + end_chunk.delta.finish_reason = "COMPLETE" + end_chunk.delta.usage = None + + mock_stream = iter([thinking_chunk, thinking_chunk2, text_chunk, text_chunk2, end_chunk]) + + generator.client.chat_stream = MagicMock(return_value=mock_stream) + + messages = [ChatMessage.from_user("What is the area of a circle with radius 5?")] + streaming_chunks = [] + + def callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + result = generator.run(messages=messages, streaming_callback=callback) + + # Verify streaming chunks include reasoning + thinking_chunks = [c for c in streaming_chunks if c.reasoning is not None] + assert len(thinking_chunks) == 2 + assert "step by step" in thinking_chunks[0].reasoning.reasoning_text + assert "formula" in thinking_chunks[1].reasoning.reasoning_text.lower() + + # Verify final message has reasoning + assert "replies" in result + reply = result["replies"][0] + assert reply.reasoning is not None + assert "step by step" in reply.reasoning.reasoning_text + assert "formula" in reply.reasoning.reasoning_text.lower() + assert reply.text == "The area is 78.54 square units." + def test_live_run_with_mixed_tools(self): """ Integration test that verifies CohereChatGenerator works with mixed Tool and Toolset.