Merged
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "codealive-mcp"
version = "0.3.0"
version = "0.4.0"
description = "MCP server for the CodeAlive API"
readme = "README.md"
requires-python = ">=3.11"
6 changes: 3 additions & 3 deletions src/codealive_mcp_server.py
@@ -54,9 +54,9 @@
- Remember that context from previous messages is maintained in the same conversation

Flexible data source usage:
- You can use a workspace ID as a single data source to search or chat across all its repositories at once
- Alternatively, you can use specific repository IDs for more targeted searches
- For complex queries, you can combine multiple repository IDs from different workspaces
- You can use a workspace name as a single data source to search or chat across all its repositories at once
- Alternatively, you can use specific repository names for more targeted searches
- For complex queries, you can combine multiple repository names from different workspaces
- Choose between workspace-level or repository-level access based on the scope of the query

Repository integration:
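Note on the docstring change above: the guidance now points callers at workspace and repository names rather than IDs. A minimal sketch of what a consultant call might look like under that contract, assuming the consultant tool exposes the same data_sources parameter as the search tool; the repository and workspace names below are made up for illustration:

# Hypothetical names, for illustration only.
targeted_call = {
    "tool": "codebase_consultant",
    "arguments": {
        "question": "How is authentication wired between these services?",
        "data_sources": ["frontend-app", "backend-api"],  # two specific repositories
    },
}

broad_call = {
    "tool": "codebase_consultant",
    "arguments": {
        "question": "Summarize the deployment pipeline.",
        "data_sources": ["platform-workspace"],  # one workspace covers all its repositories
    },
}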
32 changes: 16 additions & 16 deletions src/tests/test_chat_tool.py
@@ -9,8 +9,8 @@

@pytest.mark.asyncio
@patch('tools.chat.get_api_key_from_context')
async def test_consultant_with_simple_ids(mock_get_api_key):
"""Test codebase consultant with simple string IDs."""
async def test_consultant_with_simple_names(mock_get_api_key):
"""Test codebase consultant with simple string names."""
mock_get_api_key.return_value = "test_key"

ctx = MagicMock(spec=Context)
@@ -39,7 +39,7 @@ async def mock_aiter_lines():

ctx.request_context.lifespan_context = mock_codealive_context

# Test with simple string IDs
# Test with simple string names
result = await codebase_consultant(
ctx=ctx,
question="Test question",
@@ -50,19 +50,19 @@ async def mock_aiter_lines():
call_args = mock_client.post.call_args
request_data = call_args.kwargs["json"]

# Should convert simple IDs to {"id": "..."} format
assert request_data["dataSources"] == [
{"id": "repo123"},
{"id": "repo456"}
# Should convert simple names to the backend names array
assert request_data["names"] == [
"repo123",
"repo456"
]

assert result == "Hello world"


@pytest.mark.asyncio
@patch('tools.chat.get_api_key_from_context')
async def test_consultant_preserves_string_ids(mock_get_api_key):
"""Test codebase consultant preserves string IDs."""
async def test_consultant_preserves_string_names(mock_get_api_key):
"""Test codebase consultant preserves string names."""
mock_get_api_key.return_value = "test_key"

ctx = MagicMock(spec=Context)
@@ -88,7 +88,7 @@ async def mock_aiter_lines():

ctx.request_context.lifespan_context = mock_codealive_context

# Test with string IDs
# Test with string names
result = await codebase_consultant(
ctx=ctx,
question="Test",
@@ -98,10 +98,10 @@ async def mock_aiter_lines():
call_args = mock_client.post.call_args
request_data = call_args.kwargs["json"]

# Should extract just the ID
assert request_data["dataSources"] == [
{"id": "repo123"},
{"id": "repo456"}
# Should extract just the normalized names
assert request_data["names"] == [
"repo123",
"repo456"
]

assert result == "Response"
@@ -145,8 +145,8 @@ async def mock_aiter_lines():

# Should include conversation ID
assert request_data["conversationId"] == "conv_123"
# Should not have data sources when continuing conversation
assert "dataSources" not in request_data
# Should not have explicit names when continuing conversation
assert "names" not in request_data

assert result == "Continued"

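Taken together, these tests pin down the new chat request shape: a fresh conversation sends a flat names array of plain strings, while a continued conversation sends only conversationId and omits names entirely. A rough sketch of payload construction consistent with those assertions; everything other than the names and conversationId fields (the messages field, the helper itself) is assumed for illustration:

def build_chat_payload(question, data_sources=None, conversation_id=None):
    # Illustrative only: shape the request body the way the tests above expect.
    payload = {"messages": [{"role": "user", "content": question}]}
    if conversation_id:
        # Continuation: the backend already knows the data sources.
        payload["conversationId"] = conversation_id
    elif data_sources:
        # New conversation: plain name strings, not {"id": ...} objects.
        payload["names"] = [str(s) for s in data_sources]
    return payload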
50 changes: 20 additions & 30 deletions src/tests/test_error_handling.py
@@ -3,7 +3,7 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
import httpx
from utils.errors import handle_api_error, format_data_source_ids
from utils.errors import handle_api_error, format_data_source_names


@pytest.mark.asyncio
@@ -109,57 +109,47 @@ async def test_handle_unknown_http_error():
assert len(result) < 300


def test_format_data_source_ids_strings():
"""Test formatting simple string IDs."""
def test_format_data_source_names_strings():
"""Test formatting simple string names."""
input_data = ["id1", "id2", "id3"]
result = format_data_source_ids(input_data)
result = format_data_source_names(input_data)

assert result == [
{"id": "id1"},
{"id": "id2"},
{"id": "id3"}
]
assert result == ["id1", "id2", "id3"]


def test_format_data_source_ids_dicts():
"""Test formatting dictionary IDs."""
def test_format_data_source_names_dicts():
"""Test formatting dictionary inputs."""
input_data = [
{"id": "id1"},
{"type": "repository", "id": "id2"},
{"name": "repo-name"},
{"id": "id3", "extra": "field"}
]
result = format_data_source_ids(input_data)
result = format_data_source_names(input_data)

assert result == [
{"id": "id1"},
{"id": "id2"},
{"id": "id3"}
]
assert result == ["id1", "id2", "repo-name", "id3"]


def test_format_data_source_ids_mixed():
"""Test formatting mixed format IDs."""
def test_format_data_source_names_mixed():
"""Test formatting mixed format inputs."""
input_data = [
"id1",
{"id": "id2"},
{"type": "workspace", "id": "id3"},
"", # Empty string - should be skipped
None, # None - should be skipped
{"no_id": "field"}, # Missing id - should be skipped
{"name": "repo-name"},
"id4"
]
result = format_data_source_ids(input_data)
result = format_data_source_names(input_data)

assert result == [
{"id": "id1"},
{"id": "id2"},
{"id": "id3"},
{"id": "id4"}
]
assert result == ["id1", "id2", "id3", "repo-name", "id4"]


def test_format_data_source_ids_empty():
def test_format_data_source_names_empty():
"""Test formatting empty/None inputs."""
assert format_data_source_ids(None) == []
assert format_data_source_ids([]) == []
assert format_data_source_ids([None, "", {}]) == []
assert format_data_source_names(None) == []
assert format_data_source_names([]) == []
assert format_data_source_names([None, "", {}]) == []

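A minimal format_data_source_names that satisfies the assertions above could look like the sketch below. Preferring a dict's name over its id is an assumption borrowed from the normalization tests; these cases never supply both keys at once:

def format_data_source_names(data_sources):
    # Illustrative sketch: flatten mixed inputs into a flat list of name strings.
    names = []
    for source in data_sources or []:
        if isinstance(source, str):
            if source:  # empty strings are skipped
                names.append(source)
        elif isinstance(source, dict):
            value = source.get("name") or source.get("id")
            if value:  # dicts with neither key are skipped
                names.append(value)
        # None and any other type is silently dropped
    return names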
63 changes: 36 additions & 27 deletions src/tests/test_parameter_normalization.py
@@ -2,41 +2,41 @@

import pytest
import json
from utils.errors import normalize_data_source_ids
from utils.errors import normalize_data_source_names


class TestNormalizeDataSourceIds:
"""Test the normalize_data_source_ids function with various input formats."""
class TestNormalizeDataSourceNames:
"""Test the normalize_data_source_names function with various input formats."""

def test_proper_array_input(self):
"""Test that proper arrays are passed through unchanged."""
input_data = ["repo1", "repo2", "repo3"]
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ["repo1", "repo2", "repo3"]

def test_single_string_input(self):
"""Test that single string is converted to array."""
input_data = "repo1"
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ["repo1"]

def test_json_encoded_string_input(self):
"""Test that JSON-encoded strings are properly parsed."""
input_data = '["repo1", "repo2"]'
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ["repo1", "repo2"]

def test_malformed_json_string_fallback(self):
"""Test that malformed JSON strings fall back to single ID."""
input_data = '["repo1", "repo2"' # Missing closing bracket
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ['["repo1", "repo2"'] # Treated as single ID

def test_empty_inputs(self):
"""Test various empty input types."""
assert normalize_data_source_ids(None) == []
assert normalize_data_source_ids("") == []
assert normalize_data_source_ids([]) == []
assert normalize_data_source_names(None) == []
assert normalize_data_source_names("") == []
assert normalize_data_source_names([]) == []

def test_mixed_array_with_dicts(self):
"""Test arrays containing both strings and dict objects."""
@@ -46,61 +46,70 @@ def test_mixed_array_with_dicts(self):
"repo3",
{"id": "workspace1", "type": "workspace"}
]
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ["repo1", "repo2", "repo3", "workspace1"]

def test_dict_without_id(self):
"""Test that dicts without 'id' field are skipped."""
"""Test that dicts without 'id' field use 'name' field if present."""
input_data = [
"repo1",
{"name": "some-repo", "type": "repository"}, # No 'id' field
{"name": "some-repo", "type": "repository"}, # No 'id' field, but has 'name'
"repo2"
]
result = normalize_data_source_ids(input_data)
assert result == ["repo1", "repo2"]
result = normalize_data_source_names(input_data)
assert result == ["repo1", "some-repo", "repo2"]

def test_empty_strings_preserved(self):
"""Test that empty strings in arrays are preserved (might be intentional)."""
input_data = ["repo1", "", "repo2", " ", "repo3"]
result = normalize_data_source_ids(input_data)
result = normalize_data_source_names(input_data)
assert result == ["repo1", "", "repo2", " ", "repo3"] # All strings preserved

def test_non_list_non_string_input(self):
"""Test handling of unexpected input types."""
result = normalize_data_source_ids(123)
result = normalize_data_source_names(123)
assert result == ["123"]

result = normalize_data_source_ids({"id": "repo1"})
result = normalize_data_source_names({"id": "repo1"})
assert result == ["{'id': 'repo1'}"]

def test_claude_desktop_scenarios(self):
"""Test specific scenarios from Claude Desktop serialization issues."""
# Scenario 1: JSON string as seen in Claude Desktop logs
claude_input_1 = '["67db4097fa23c0a98a8495c2"]'
result_1 = normalize_data_source_ids(claude_input_1)
result_1 = normalize_data_source_names(claude_input_1)
assert result_1 == ["67db4097fa23c0a98a8495c2"]

# Scenario 2: Plain string as seen in Claude Desktop logs
claude_input_2 = "67db4097fa23c0a98a8495c2"
result_2 = normalize_data_source_ids(claude_input_2)
result_2 = normalize_data_source_names(claude_input_2)
assert result_2 == ["67db4097fa23c0a98a8495c2"]

# Scenario 3: Multiple IDs in JSON string
claude_input_3 = '["repo1", "repo2", "workspace1"]'
result_3 = normalize_data_source_ids(claude_input_3)
result_3 = normalize_data_source_names(claude_input_3)
assert result_3 == ["repo1", "repo2", "workspace1"]

def test_edge_cases(self):
"""Test various edge cases."""
# Whitespace-only JSON string
assert normalize_data_source_ids("[]") == []
assert normalize_data_source_ids("[ ]") == []
assert normalize_data_source_names("[]") == []
assert normalize_data_source_names("[ ]") == []

# Single item JSON array
assert normalize_data_source_ids('["single"]') == ["single"]
assert normalize_data_source_names('["single"]') == ["single"]

# JSON array with empty strings
assert normalize_data_source_ids('["repo1", "", "repo2"]') == ["repo1", "", "repo2"]
assert normalize_data_source_names('["repo1", "", "repo2"]') == ["repo1", "", "repo2"]

def test_dict_with_name_preferred(self):
"""Dict inputs with explicit names should take precedence over IDs."""
input_data = [
{"id": "legacy-id", "name": "repo-main"},
{"name": "workspace:analytics"}
]
result = normalize_data_source_names(input_data)
assert result == ["repo-main", "workspace:analytics"]


class TestParameterNormalizationIntegration:
@@ -113,10 +113,10 @@ def test_search_tool_parameter_handling(self):

# Verify the function accepts Union[str, List[str]]
sig = inspect.signature(codebase_search)
data_source_ids_param = sig.parameters['data_source_ids']
data_sources_param = sig.parameters['data_sources']

# The annotation should accept both str and List[str]
assert 'Union' in str(data_source_ids_param.annotation) or 'str' in str(data_source_ids_param.annotation)
assert 'Union' in str(data_sources_param.annotation) or 'str' in str(data_sources_param.annotation)

def test_consultant_tool_parameter_handling(self):
"""Test that consultant tool properly normalizes various parameter formats."""
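The cases above pin down the coercion rules: real lists pass through with dicts reduced to their name (falling back to id), a bare string becomes a one-element list, a JSON-encoded array is parsed, malformed JSON falls back to a single entry, and any other scalar is stringified. A sketch consistent with those expectations, offered as an illustration rather than the PR's actual utils.errors implementation:

import json

def normalize_data_source_names(value):
    # Illustrative sketch of the coercion the tests above describe.
    if value is None or value == "" or value == []:
        return []
    if isinstance(value, str):
        if value.strip().startswith("["):
            try:
                value = json.loads(value)  # e.g. '["repo1", "repo2"]'
            except json.JSONDecodeError:
                return [value]  # malformed JSON: keep the raw string as one name
        else:
            return [value]  # plain string: a single name
    if isinstance(value, list):
        names = []
        for item in value:
            if isinstance(item, dict):
                name = item.get("name") or item.get("id")
                if name:
                    names.append(name)  # dicts without either key are skipped
            else:
                names.append(item)  # strings, including empty ones, are preserved
        return names
    return [str(value)]  # ints, dicts, anything else unexpected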
2 changes: 1 addition & 1 deletion src/tests/test_response_transformer.py
@@ -323,7 +323,7 @@ def test_data_preservation_with_content(self):
"range": {"start": {"line": 18}, "end": {"line": 168}}
},
"score": 0.99,
"content": "async def codebase_search(\n ctx: Context,\n query: str,\n data_source_ids: Optional[List[str]] = None,\n mode: str = \"auto\",\n include_content: bool = False\n) -> Dict:",
"content": "async def codebase_search(\n ctx: Context,\n query: str,\n data_sources: Optional[List[str]] = None,\n mode: str = \"auto\",\n include_content: bool = False\n) -> Dict:",
"dataSource": {
"type": "repository",
"id": "685b21230e3822f4efa9d073",
11 changes: 8 additions & 3 deletions src/tests/test_search_tool.py
@@ -51,7 +51,7 @@ async def test_codebase_search_returns_xml_string(mock_get_api_key):
result = await codebase_search(
ctx=ctx,
query="authenticate_user",
data_source_ids=["test_id"],
data_sources=["test-name"],
mode="auto",
include_content=False
)
@@ -63,6 +63,11 @@ async def test_codebase_search_returns_xml_string(mock_get_api_key):
assert "<results>" in result, "Should contain results tag"
assert "<search_result" in result, "Should contain search_result tag"

# Verify the request used the Names query parameter
call_args = mock_client.get.call_args
params = call_args.kwargs["params"]
assert ("Names", "test-name") in params


@pytest.mark.asyncio
async def test_codebase_search_empty_query_returns_error_string():
@@ -82,7 +87,7 @@ async def test_codebase_search_empty_query_returns_error_string():
result = await codebase_search(
ctx=ctx,
query="",
data_source_ids=["test_id"],
data_sources=["test-name"],
mode="auto",
include_content=False
)
@@ -138,7 +143,7 @@ def raise_404():
result = await codebase_search(
ctx=ctx,
query="test query",
data_source_ids=["invalid_id"],
data_sources=["invalid-name"],
mode="auto",
include_content=False
)
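The new assertion near the top of this file checks that each data source name is sent as its own Names query parameter. One way to produce repeated keys with httpx is to pass params as a list of tuples; the endpoint URL and the other parameter names here are illustrative, not taken from the PR:

import httpx

def build_search_params(query, data_sources, mode="auto"):
    # Illustrative: repeat the Names key once per data source name.
    params = [("Query", query), ("Mode", mode)]
    params.extend(("Names", name) for name in data_sources)
    return params

# ...?Query=auth&Mode=auto&Names=frontend-app&Names=backend-api
params = build_search_params("auth", ["frontend-app", "backend-api"])
request = httpx.Request("GET", "https://api.example.com/search", params=params)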