diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 020aaa41fd..b7b77b60b2 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -22,14 +22,16 @@ class CohereDocumentEmbedder: Usage example: ```python from haystack import Document - from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder + from haystack_integrations.components.embedders.cohere import ( + CohereDocumentEmbedder, + ) doc = Document(content="I love pizza!") document_embedder = CohereDocumentEmbedder() result = document_embedder.run([doc]) - print(result['documents'][0].embedding) + print(result["documents"][0].embedding) # [-0.453125, 1.2236328, 2.0058594, ...] ``` diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py index 1484365819..7b5316eb99 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_image_embedder.py @@ -37,7 +37,9 @@ class CohereDocumentImageEmbedder: ### Usage example ```python from haystack import Document - from haystack_integrations.components.embedders.cohere import CohereDocumentImageEmbedder + from haystack_integrations.components.embedders.cohere import ( + CohereDocumentImageEmbedder, + ) embedder = CohereDocumentImageEmbedder(model="embed-v4.0") diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 6c01c81cf5..593dbeb292 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -4,7 +4,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message -from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall +from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall from haystack.dataclasses.streaming_chunk import ( AsyncStreamingCallbackT, FinishReason, @@ -184,6 +184,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: Extracts and organizes various response components including: - Text content - Tool calls + - Reasoning content (via native Cohere API) - Usage statistics - Citations - Metadata @@ -192,6 +193,21 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: :param model: The name of the model that generated the response. :return: A Haystack ChatMessage containing the formatted response. """ + # Extract reasoning content using Cohere's native API + reasoning_content = None + text_content = "" + + if chat_response.message.content: + for content_item in chat_response.message.content: + # Access thinking content via native Cohere API + if hasattr(content_item, "type") and content_item.type == "thinking": + if hasattr(content_item, "thinking") and content_item.thinking: + reasoning_content = ReasoningContent(reasoning_text=content_item.thinking) + # Access text content + elif hasattr(content_item, "type") and content_item.type == "text": + if hasattr(content_item, "text") and content_item.text: + text_content = content_item.text + if chat_response.message.tool_calls: tool_calls = [] for tc in chat_response.message.tool_calls: @@ -206,9 +222,9 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: # Create message with tool plan as text and tool calls in the format Haystack expects tool_plan = chat_response.message.tool_plan or "" - message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) - elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"): - message = ChatMessage.from_assistant(chat_response.message.content[0].text) + message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content) + elif text_content: + message = ChatMessage.from_assistant(text_content, reasoning=reasoning_content) else: # Handle the case where neither tool_calls nor content exists logger.warning(f"Received empty response from Cohere API: {chat_response.message}") @@ -266,11 +282,19 @@ def _convert_cohere_chunk_to_streaming_chunk( start = False finish_reason = None tool_calls = None + reasoning = None meta: dict[str, Any] = {"model": model} if chunk.type == "content-delta" and chunk.delta and chunk.delta.message: - if chunk.delta.message and chunk.delta.message.content and chunk.delta.message.content.text is not None: - content = chunk.delta.message.content.text + if chunk.delta.message and chunk.delta.message.content: + # Handle thinking/reasoning content (prioritize over text if both exist) + # Note: StreamingChunk can only have ONE of content/tool_calls/reasoning set + thinking_text = getattr(chunk.delta.message.content, "thinking", None) + if thinking_text is not None and isinstance(thinking_text, str): + reasoning = ReasoningContent(reasoning_text=thinking_text) + # Handle text content (only if no reasoning) + elif chunk.delta.message.content.text is not None: + content = chunk.delta.message.content.text elif chunk.type == "tool-plan-delta" and chunk.delta and chunk.delta.message: if chunk.delta.message and chunk.delta.message.tool_plan is not None: @@ -346,6 +370,7 @@ def _convert_cohere_chunk_to_streaming_chunk( component_info=component_info, index=index, tool_calls=tool_calls, + reasoning=reasoning, start=start, finish_reason=finish_reason, meta=meta, @@ -437,9 +462,13 @@ class CohereChatGenerator: ```python from haystack.dataclasses import ChatMessage from haystack.utils import Secret - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) - client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY")) + client = CohereChatGenerator( + model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY") + ) messages = [ChatMessage.from_user("What's Natural Language Processing?")] client.run(messages) @@ -451,16 +480,25 @@ class CohereChatGenerator: ```python from haystack.dataclasses import ChatMessage, ImageContent from haystack.utils import Secret - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) # Create an image from file path or base64 image_content = ImageContent.from_file_path("path/to/your/image.jpg") # Create a multimodal message with both text and image - messages = [ChatMessage.from_user(content_parts=["What's in this image?", image_content])] + messages = [ + ChatMessage.from_user( + content_parts=["What's in this image?", image_content] + ) + ] # Use a multimodal model like Command A Vision - client = CohereChatGenerator(model="command-a-vision-07-2025", api_key=Secret.from_env_var("COHERE_API_KEY")) + client = CohereChatGenerator( + model="command-a-vision-07-2025", + api_key=Secret.from_env_var("COHERE_API_KEY"), + ) response = client.run(messages) print(response) ``` @@ -475,12 +513,16 @@ class CohereChatGenerator: from haystack.dataclasses import ChatMessage from haystack.components.tools import ToolInvoker from haystack.tools import Tool - from haystack_integrations.components.generators.cohere import CohereChatGenerator + from haystack_integrations.components.generators.cohere import ( + CohereChatGenerator, + ) + # Create a weather tool def weather(city: str) -> str: return f"The weather in {city} is sunny and 32°C" + weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", @@ -499,13 +541,22 @@ def weather(city: str) -> str: # Create and set up the pipeline pipeline = Pipeline() - pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool])) + pipeline.add_component( + "generator", + CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]), + ) pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) pipeline.connect("generator", "tool_invoker") # Run the pipeline with a weather query results = pipeline.run( - data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + data={ + "generator": { + "messages": [ + ChatMessage.from_user("What's the weather like in Paris?") + ] + } + } ) # The tool result will be available in the pipeline output diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index d3ddc7fb15..f4e57192d4 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -7,15 +7,13 @@ from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk from haystack.components.tools import ToolInvoker -from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ToolCall +from haystack.dataclasses import ChatMessage, ChatRole, ImageContent, ReasoningContent, ToolCall from haystack.dataclasses.streaming_chunk import StreamingChunk from haystack.tools import Tool, Toolset from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator -from haystack_integrations.components.generators.cohere.chat.chat_generator import ( - _format_message, -) +from haystack_integrations.components.generators.cohere.chat.chat_generator import _format_message def weather(city: str) -> str: @@ -444,11 +442,14 @@ def test_run_image(self): generator = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) - # Mock the client's chat method + # Mock the client's chat method with proper content structure mock_response = MagicMock() - mock_response.message.content = [MagicMock()] - mock_response.message.content[0].text = "This is a test image response" + text_content = MagicMock() + text_content.type = "text" + text_content.text = "This is a test image response" + mock_response.message.content = [text_content] mock_response.message.tool_calls = None + mock_response.message.citations = None mock_response.finish_reason = "COMPLETE" mock_response.usage = None @@ -727,6 +728,209 @@ def test_live_run_multimodal(self): assert isinstance(results["replies"][0], ChatMessage) assert len(results["replies"][0].text) > 0 + def test_reasoning_with_command_a_reasoning_model(self): + """Test reasoning extraction with Command A Reasoning model.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", + generation_kwargs={"thinking": True}, # Enable reasoning + ) + + messages = [ + ChatMessage.from_user("Solve this math problem step by step: What is the area of a circle with radius 7?") + ] + + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check if reasoning was extracted + if reply.reasoning: + assert isinstance(reply.reasoning, ReasoningContent) + assert len(reply.reasoning.reasoning_text) > 50 # Should have substantial reasoning + + # The reasoning should contain mathematical thinking + reasoning_lower = reply.reasoning.reasoning_text.lower() + assert any(word in reasoning_lower for word in ["area", "circle", "radius", "formula", "π", "pi"]) + + # Check the main response content + assert len(reply.text) > 0 + response_lower = reply.text.lower() + assert any(word in response_lower for word in ["area", "153.94", "154", "square"]) + + def test_reasoning_with_mock_response(self): + """Test reasoning extraction with mocked Cohere response using native API.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") + ) + + # Mock the Cohere client response using native API structure + mock_response = MagicMock() + + # Create mock content items with thinking and text types + thinking_content = MagicMock() + thinking_content.type = "thinking" + thinking_content.thinking = """I need to solve for the area of a circle. +The formula is A = πr² +With radius 7: A = π * 7² = π * 49 ≈ 153.94""" + + text_content = MagicMock() + text_content.type = "text" + text_content.text = "The area of a circle with radius 7 is approximately 153.94 square units." + + mock_response.message.content = [thinking_content, text_content] + mock_response.message.tool_calls = None + mock_response.message.citations = None + mock_response.finish_reason = "COMPLETE" + mock_response.usage = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What is the area of a circle with radius 7?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + assert reply.role == ChatRole.ASSISTANT + + # Check reasoning extraction via native API + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "formula is A = πr²" in reply.reasoning.reasoning_text + assert "π * 49 ≈ 153.94" in reply.reasoning.reasoning_text + + # Check text content + assert reply.text.strip() == "The area of a circle with radius 7 is approximately 153.94 square units." + + def test_reasoning_with_tool_calls_compatibility(self): + """Test that reasoning works with tool calls.""" + weather_tool = Tool( + name="weather", + description="Get weather for a city", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + function=weather, + ) + + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", tools=[weather_tool], api_key=Secret.from_token("fake-api-key") + ) + + # Mock response with both reasoning and tool calls using native API + mock_response = MagicMock() + + # Create mock content items with thinking type + thinking_content = MagicMock() + thinking_content.type = "thinking" + thinking_content.thinking = ( + "The user is asking about weather in Paris. I should use the weather tool to get accurate information." + ) + + mock_response.message.content = [thinking_content] + + # Mock tool call + mock_tool_call = MagicMock() + mock_tool_call.function.name = "weather" + mock_tool_call.function.arguments = '{"city": "Paris"}' + mock_tool_call.id = "call_123" + mock_response.message.tool_calls = [mock_tool_call] + mock_response.message.tool_plan = "I'll check the weather in Paris for you." + mock_response.message.citations = None + mock_response.finish_reason = "TOOL_CALLS" + mock_response.usage = None + + generator.client.chat = MagicMock(return_value=mock_response) + + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + result = generator.run(messages=messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + + reply = result["replies"][0] + assert isinstance(reply, ChatMessage) + + # Check reasoning extraction via native API + assert reply.reasoning is not None + assert isinstance(reply.reasoning, ReasoningContent) + assert "weather tool" in reply.reasoning.reasoning_text + + # Check tool calls are preserved + assert reply.tool_calls is not None + assert len(reply.tool_calls) == 1 + assert reply.tool_calls[0].tool_name == "weather" + + # Check tool plan is used as text + assert "I'll check the weather in Paris" in reply.text + + def test_streaming_reasoning_with_mock_chunks(self): + """Test that reasoning content is captured during streaming.""" + generator = CohereChatGenerator( + model="command-a-reasoning-111b-2024-10-03", api_key=Secret.from_token("fake-api-key") + ) + + # Mock streaming chunks with thinking content + thinking_chunk = MagicMock() + thinking_chunk.type = "content-delta" + thinking_chunk.delta.message.content.thinking = "Let me calculate the area step by step. " + thinking_chunk.delta.message.content.text = None + + thinking_chunk2 = MagicMock() + thinking_chunk2.type = "content-delta" + thinking_chunk2.delta.message.content.thinking = "Using formula A = πr²." + thinking_chunk2.delta.message.content.text = None + + text_chunk = MagicMock() + text_chunk.type = "content-delta" + text_chunk.delta.message.content.text = "The area is " + text_chunk.delta.message.content.thinking = None + + text_chunk2 = MagicMock() + text_chunk2.type = "content-delta" + text_chunk2.delta.message.content.text = "78.54 square units." + text_chunk2.delta.message.content.thinking = None + + end_chunk = MagicMock() + end_chunk.type = "message-end" + end_chunk.delta.finish_reason = "COMPLETE" + end_chunk.delta.usage = None + + mock_stream = iter([thinking_chunk, thinking_chunk2, text_chunk, text_chunk2, end_chunk]) + + generator.client.chat_stream = MagicMock(return_value=mock_stream) + + messages = [ChatMessage.from_user("What is the area of a circle with radius 5?")] + streaming_chunks = [] + + def callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + result = generator.run(messages=messages, streaming_callback=callback) + + # Verify streaming chunks include reasoning + thinking_chunks = [c for c in streaming_chunks if c.reasoning is not None] + assert len(thinking_chunks) == 2 + assert "step by step" in thinking_chunks[0].reasoning.reasoning_text + assert "formula" in thinking_chunks[1].reasoning.reasoning_text.lower() + + # Verify final message has reasoning + assert "replies" in result + reply = result["replies"][0] + assert reply.reasoning is not None + assert "step by step" in reply.reasoning.reasoning_text + assert "formula" in reply.reasoning.reasoning_text.lower() + assert reply.text == "The area is 78.54 square units." + def test_live_run_with_mixed_tools(self): """ Integration test that verifies CohereChatGenerator works with mixed Tool and Toolset.