diff --git a/pyproject.toml b/pyproject.toml index 03db6ff..5106f3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "codealive-mcp" -version = "0.3.0" +version = "0.4.0" description = "MCP server for the CodeAlive API" readme = "README.md" requires-python = ">=3.11" diff --git a/src/codealive_mcp_server.py b/src/codealive_mcp_server.py index 0e81b7e..9ea9e88 100644 --- a/src/codealive_mcp_server.py +++ b/src/codealive_mcp_server.py @@ -54,9 +54,9 @@ - Remember that context from previous messages is maintained in the same conversation Flexible data source usage: - - You can use a workspace ID as a single data source to search or chat across all its repositories at once - - Alternatively, you can use specific repository IDs for more targeted searches - - For complex queries, you can combine multiple repository IDs from different workspaces + - You can use a workspace name as a single data source to search or chat across all its repositories at once + - Alternatively, you can use specific repository names for more targeted searches + - For complex queries, you can combine multiple repository names from different workspaces - Choose between workspace-level or repository-level access based on the scope of the query Repository integration: diff --git a/src/tests/test_chat_tool.py b/src/tests/test_chat_tool.py index 6d7ad74..762bb2a 100644 --- a/src/tests/test_chat_tool.py +++ b/src/tests/test_chat_tool.py @@ -9,8 +9,8 @@ @pytest.mark.asyncio @patch('tools.chat.get_api_key_from_context') -async def test_consultant_with_simple_ids(mock_get_api_key): - """Test codebase consultant with simple string IDs.""" +async def test_consultant_with_simple_names(mock_get_api_key): + """Test codebase consultant with simple string names.""" mock_get_api_key.return_value = "test_key" ctx = MagicMock(spec=Context) @@ -39,7 +39,7 @@ async def mock_aiter_lines(): ctx.request_context.lifespan_context = mock_codealive_context - # Test with simple string IDs + 
# Test with simple string names result = await codebase_consultant( ctx=ctx, question="Test question", @@ -50,10 +50,10 @@ async def mock_aiter_lines(): call_args = mock_client.post.call_args request_data = call_args.kwargs["json"] - # Should convert simple IDs to {"id": "..."} format - assert request_data["dataSources"] == [ - {"id": "repo123"}, - {"id": "repo456"} + # Should convert simple names to the backend names array + assert request_data["names"] == [ + "repo123", + "repo456" ] assert result == "Hello world" @@ -61,8 +61,8 @@ async def mock_aiter_lines(): @pytest.mark.asyncio @patch('tools.chat.get_api_key_from_context') -async def test_consultant_preserves_string_ids(mock_get_api_key): - """Test codebase consultant preserves string IDs.""" +async def test_consultant_preserves_string_names(mock_get_api_key): + """Test codebase consultant preserves string names.""" mock_get_api_key.return_value = "test_key" ctx = MagicMock(spec=Context) @@ -88,7 +88,7 @@ async def mock_aiter_lines(): ctx.request_context.lifespan_context = mock_codealive_context - # Test with string IDs + # Test with string names result = await codebase_consultant( ctx=ctx, question="Test", @@ -98,10 +98,10 @@ async def mock_aiter_lines(): call_args = mock_client.post.call_args request_data = call_args.kwargs["json"] - # Should extract just the ID - assert request_data["dataSources"] == [ - {"id": "repo123"}, - {"id": "repo456"} + # Should extract just the normalized names + assert request_data["names"] == [ + "repo123", + "repo456" ] assert result == "Response" @@ -145,8 +145,8 @@ async def mock_aiter_lines(): # Should include conversation ID assert request_data["conversationId"] == "conv_123" - # Should not have data sources when continuing conversation - assert "dataSources" not in request_data + # Should not have explicit names when continuing conversation + assert "names" not in request_data assert result == "Continued" diff --git a/src/tests/test_error_handling.py 
b/src/tests/test_error_handling.py index 5516faa..7e9dac3 100644 --- a/src/tests/test_error_handling.py +++ b/src/tests/test_error_handling.py @@ -3,7 +3,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock import httpx -from utils.errors import handle_api_error, format_data_source_ids +from utils.errors import handle_api_error, format_data_source_names @pytest.mark.asyncio @@ -109,36 +109,29 @@ async def test_handle_unknown_http_error(): assert len(result) < 300 -def test_format_data_source_ids_strings(): - """Test formatting simple string IDs.""" +def test_format_data_source_names_strings(): + """Test formatting simple string names.""" input_data = ["id1", "id2", "id3"] - result = format_data_source_ids(input_data) + result = format_data_source_names(input_data) - assert result == [ - {"id": "id1"}, - {"id": "id2"}, - {"id": "id3"} - ] + assert result == ["id1", "id2", "id3"] -def test_format_data_source_ids_dicts(): - """Test formatting dictionary IDs.""" +def test_format_data_source_names_dicts(): + """Test formatting dictionary inputs.""" input_data = [ {"id": "id1"}, {"type": "repository", "id": "id2"}, + {"name": "repo-name"}, {"id": "id3", "extra": "field"} ] - result = format_data_source_ids(input_data) + result = format_data_source_names(input_data) - assert result == [ - {"id": "id1"}, - {"id": "id2"}, - {"id": "id3"} - ] + assert result == ["id1", "id2", "repo-name", "id3"] -def test_format_data_source_ids_mixed(): - """Test formatting mixed format IDs.""" +def test_format_data_source_names_mixed(): + """Test formatting mixed format inputs.""" input_data = [ "id1", {"id": "id2"}, @@ -146,20 +139,17 @@ def test_format_data_source_ids_mixed(): "", # Empty string - should be skipped None, # None - should be skipped {"no_id": "field"}, # Missing id - should be skipped + {"name": "repo-name"}, "id4" ] - result = format_data_source_ids(input_data) + result = format_data_source_names(input_data) - assert result == [ - {"id": "id1"}, - {"id": 
"id2"}, - {"id": "id3"}, - {"id": "id4"} - ] + assert result == ["id1", "id2", "id3", "repo-name", "id4"] -def test_format_data_source_ids_empty(): +def test_format_data_source_names_empty(): """Test formatting empty/None inputs.""" - assert format_data_source_ids(None) == [] - assert format_data_source_ids([]) == [] - assert format_data_source_ids([None, "", {}]) == [] \ No newline at end of file + assert format_data_source_names(None) == [] + assert format_data_source_names([]) == [] + assert format_data_source_names([None, "", {}]) == [] + diff --git a/src/tests/test_parameter_normalization.py b/src/tests/test_parameter_normalization.py index 83cd08e..784896f 100644 --- a/src/tests/test_parameter_normalization.py +++ b/src/tests/test_parameter_normalization.py @@ -2,41 +2,41 @@ import pytest import json -from utils.errors import normalize_data_source_ids +from utils.errors import normalize_data_source_names -class TestNormalizeDataSourceIds: - """Test the normalize_data_source_ids function with various input formats.""" +class TestNormalizeDataSourceNames: + """Test the normalize_data_source_names function with various input formats.""" def test_proper_array_input(self): """Test that proper arrays are passed through unchanged.""" input_data = ["repo1", "repo2", "repo3"] - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ["repo1", "repo2", "repo3"] def test_single_string_input(self): """Test that single string is converted to array.""" input_data = "repo1" - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ["repo1"] def test_json_encoded_string_input(self): """Test that JSON-encoded strings are properly parsed.""" input_data = '["repo1", "repo2"]' - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ["repo1", "repo2"] def test_malformed_json_string_fallback(self): 
"""Test that malformed JSON strings fall back to single ID.""" input_data = '["repo1", "repo2"' # Missing closing bracket - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ['["repo1", "repo2"'] # Treated as single ID def test_empty_inputs(self): """Test various empty input types.""" - assert normalize_data_source_ids(None) == [] - assert normalize_data_source_ids("") == [] - assert normalize_data_source_ids([]) == [] + assert normalize_data_source_names(None) == [] + assert normalize_data_source_names("") == [] + assert normalize_data_source_names([]) == [] def test_mixed_array_with_dicts(self): """Test arrays containing both strings and dict objects.""" @@ -46,61 +46,70 @@ def test_mixed_array_with_dicts(self): "repo3", {"id": "workspace1", "type": "workspace"} ] - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ["repo1", "repo2", "repo3", "workspace1"] def test_dict_without_id(self): - """Test that dicts without 'id' field are skipped.""" + """Test that dicts without 'id' field use 'name' field if present.""" input_data = [ "repo1", - {"name": "some-repo", "type": "repository"}, # No 'id' field + {"name": "some-repo", "type": "repository"}, # No 'id' field, but has 'name' "repo2" ] - result = normalize_data_source_ids(input_data) - assert result == ["repo1", "repo2"] + result = normalize_data_source_names(input_data) + assert result == ["repo1", "some-repo", "repo2"] def test_empty_strings_preserved(self): """Test that empty strings in arrays are preserved (might be intentional).""" input_data = ["repo1", "", "repo2", " ", "repo3"] - result = normalize_data_source_ids(input_data) + result = normalize_data_source_names(input_data) assert result == ["repo1", "", "repo2", " ", "repo3"] # All strings preserved def test_non_list_non_string_input(self): """Test handling of unexpected input types.""" - result = 
normalize_data_source_ids(123) + result = normalize_data_source_names(123) assert result == ["123"] - result = normalize_data_source_ids({"id": "repo1"}) + result = normalize_data_source_names({"id": "repo1"}) assert result == ["{'id': 'repo1'}"] def test_claude_desktop_scenarios(self): """Test specific scenarios from Claude Desktop serialization issues.""" # Scenario 1: JSON string as seen in Claude Desktop logs claude_input_1 = '["67db4097fa23c0a98a8495c2"]' - result_1 = normalize_data_source_ids(claude_input_1) + result_1 = normalize_data_source_names(claude_input_1) assert result_1 == ["67db4097fa23c0a98a8495c2"] # Scenario 2: Plain string as seen in Claude Desktop logs claude_input_2 = "67db4097fa23c0a98a8495c2" - result_2 = normalize_data_source_ids(claude_input_2) + result_2 = normalize_data_source_names(claude_input_2) assert result_2 == ["67db4097fa23c0a98a8495c2"] # Scenario 3: Multiple IDs in JSON string claude_input_3 = '["repo1", "repo2", "workspace1"]' - result_3 = normalize_data_source_ids(claude_input_3) + result_3 = normalize_data_source_names(claude_input_3) assert result_3 == ["repo1", "repo2", "workspace1"] def test_edge_cases(self): """Test various edge cases.""" # Whitespace-only JSON string - assert normalize_data_source_ids("[]") == [] - assert normalize_data_source_ids("[ ]") == [] + assert normalize_data_source_names("[]") == [] + assert normalize_data_source_names("[ ]") == [] # Single item JSON array - assert normalize_data_source_ids('["single"]') == ["single"] + assert normalize_data_source_names('["single"]') == ["single"] # JSON array with empty strings - assert normalize_data_source_ids('["repo1", "", "repo2"]') == ["repo1", "", "repo2"] + assert normalize_data_source_names('["repo1", "", "repo2"]') == ["repo1", "", "repo2"] + + def test_dict_with_name_preferred(self): + """Dict inputs with explicit names should take precedence over IDs.""" + input_data = [ + {"id": "legacy-id", "name": "repo-main"}, + {"name": 
"workspace:analytics"} + ] + result = normalize_data_source_names(input_data) + assert result == ["repo-main", "workspace:analytics"] class TestParameterNormalizationIntegration: @@ -113,10 +122,10 @@ def test_search_tool_parameter_handling(self): # Verify the function accepts Union[str, List[str]] sig = inspect.signature(codebase_search) - data_source_ids_param = sig.parameters['data_source_ids'] + data_sources_param = sig.parameters['data_sources'] # The annotation should accept both str and List[str] - assert 'Union' in str(data_source_ids_param.annotation) or 'str' in str(data_source_ids_param.annotation) + assert 'Union' in str(data_sources_param.annotation) or 'str' in str(data_sources_param.annotation) def test_consultant_tool_parameter_handling(self): """Test that consultant tool properly normalizes various parameter formats.""" diff --git a/src/tests/test_response_transformer.py b/src/tests/test_response_transformer.py index c28e67b..a36d368 100644 --- a/src/tests/test_response_transformer.py +++ b/src/tests/test_response_transformer.py @@ -323,7 +323,7 @@ def test_data_preservation_with_content(self): "range": {"start": {"line": 18}, "end": {"line": 168}} }, "score": 0.99, - "content": "async def codebase_search(\n ctx: Context,\n query: str,\n data_source_ids: Optional[List[str]] = None,\n mode: str = \"auto\",\n include_content: bool = False\n) -> Dict:", + "content": "async def codebase_search(\n ctx: Context,\n query: str,\n data_sources: Optional[List[str]] = None,\n mode: str = \"auto\",\n include_content: bool = False\n) -> Dict:", "dataSource": { "type": "repository", "id": "685b21230e3822f4efa9d073", diff --git a/src/tests/test_search_tool.py b/src/tests/test_search_tool.py index 193ed68..acb29b9 100644 --- a/src/tests/test_search_tool.py +++ b/src/tests/test_search_tool.py @@ -51,7 +51,7 @@ async def test_codebase_search_returns_xml_string(mock_get_api_key): result = await codebase_search( ctx=ctx, query="authenticate_user", - 
data_source_ids=["test_id"], + data_sources=["test-name"], mode="auto", include_content=False ) @@ -63,6 +63,11 @@ assert "<results>" in result, "Should contain results tag" assert " str: Args: alive_only: If True (default), returns only data sources in "Alive" state ready for use with chat. - If False, returns all data sources regardless of processing state. + If False, returns all data sources regardless of processing state. Returns: A formatted list of available data sources with the following information for each: - - id: Unique identifier for the data source, used in other API calls - - name: Human-readable name of the repository or workspace + - id: Unique identifier for the data source + - name: Human-readable name of the repository or workspace, used in other API calls + - description: Summary of the codebase contents to guide search and chat usage - type: The type of data source ("Repository" or "Workspace") - url: URL of the repository (for Repository type only) - state: The processing state of the data source (if alive_only=false) @@ -44,7 +45,7 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: For repositories, the URL can be used to match with local git repositories to provide enhanced context for code understanding. - Use the returned data source IDs with the codebase_search and codebase_consultant functions. + Use the returned data source names with the codebase_search and codebase_consultant functions. """ context: CodeAliveContext = ctx.request_context.lifespan_context @@ -85,7 +86,7 @@ async def get_data_sources(ctx: Context, alive_only: bool = True) -> str: result = f"Available data sources:\n{formatted_data}" # Add usage hint - result += "\n\nYou can use these data source IDs with the codebase_search and codebase_consultant functions." + result += "\n\nYou can use these data source names with the codebase_search and codebase_consultant functions." 
return result diff --git a/src/tools/search.py b/src/tools/search.py index a8a99e7..1dac98f 100644 --- a/src/tools/search.py +++ b/src/tools/search.py @@ -7,13 +7,17 @@ from fastmcp import Context from core import CodeAliveContext, get_api_key_from_context, log_api_request, log_api_response -from utils import transform_search_response_to_xml, handle_api_error, normalize_data_source_ids +from utils import ( + transform_search_response_to_xml, + handle_api_error, + normalize_data_source_names, +) async def codebase_search( ctx: Context, query: str, - data_source_ids: Optional[Union[str, List[str]]] = None, + data_sources: Optional[Union[str, List[str]]] = None, mode: str = "auto", include_content: bool = False ) -> str: @@ -50,10 +54,10 @@ async def codebase_search( "Where do we parse OAuth callbacks?", "user registration controller" - data_source_ids: List of data source IDs to search in (required). - Can be workspace IDs (search all repositories in the workspace) - or individual repository IDs for targeted searches. - Example: ["67f664fd4c2a00698a52bb6f", "5e8f9a2c1d3b7e4a6c9d0f8e"] + data_sources: List of data source names to search in (required). + Can be workspace names (search all repositories in the workspace) + or individual repository names for targeted searches. + Example: ["enterprise-platform", "payments-team"] mode: Search mode (case-insensitive): - "auto": (Default, recommended) Adaptive semantic search. @@ -70,22 +74,22 @@ async def codebase_search( Examples: 1. Natural-language question (recommended): - codebase_search(query="What is the auth flow?", data_source_ids=["repo123"]) + codebase_search(query="What is the auth flow?", data_sources=["repo123"]) 2. Intent query: - codebase_search(query="Where is user registration logic?", data_source_ids=["repo123"]) + codebase_search(query="Where is user registration logic?", data_sources=["repo123"]) 3. 
Workspace-wide question: - codebase_search(query="How do microservices talk to the billing API?", data_source_ids=["workspace456"]) + codebase_search(query="How do microservices talk to the billing API?", data_sources=["backend-team"]) 4. Mixed query with a known identifier: - codebase_search(query="Where do we validate JWTs (AuthService)?", data_source_ids=["repo123"]) + codebase_search(query="Where do we validate JWTs (AuthService)?", data_sources=["repo123"]) 5. Concise results without full file contents: - codebase_search(query="Where is password reset handled?", data_source_ids=["repo123"], include_content=false) + codebase_search(query="Where is password reset handled?", data_sources=["repo123"], include_content=false) Note: - - At least one data_source_id must be provided + - At least one data source name must be provided - All data sources must be in "Alive" state - The API key must have access to the specified data sources - Prefer natural-language questions; templates are unnecessary. @@ -94,15 +98,15 @@ async def codebase_search( """ context: CodeAliveContext = ctx.request_context.lifespan_context - # Normalize data source IDs (handles Claude Desktop serialization issues) - data_source_ids = normalize_data_source_ids(data_source_ids) + # Normalize data source names (handles Claude Desktop serialization issues) + data_source_names = normalize_data_source_names(data_sources) # Validate inputs if not query or not query.strip(): return "Query cannot be empty. Please provide a search term, function name, or description of the code you're looking for." - if not data_source_ids or len(data_source_ids) == 0: - await ctx.info("No data source IDs provided. If the API key has exactly one assigned data source, that will be used as default.") + if not data_source_names or len(data_source_names) == 0: + await ctx.info("No data source names provided. 
If the API key has exactly one assigned data source, that will be used as default.") try: normalized_mode = mode.lower() if mode else "auto" @@ -113,23 +117,23 @@ async def codebase_search( normalized_mode = "auto" # Log the search attempt - if data_source_ids and len(data_source_ids) > 0: - await ctx.info(f"Searching for '{query}' in {len(data_source_ids)} data source(s) using {normalized_mode} mode") + if data_source_names and len(data_source_names) > 0: + await ctx.info(f"Searching for '{query}' in {len(data_source_names)} data source(s) using {normalized_mode} mode") else: await ctx.info(f"Searching for '{query}' using API key's default data source with {normalized_mode} mode") - # Prepare query parameters as a list of tuples to support multiple values for DataSourceIds + # Prepare query parameters as a list of tuples to support multiple values for Names params = [ ("Query", query), ("Mode", normalized_mode), ("IncludeContent", "true" if include_content else "false") ] - if data_source_ids and len(data_source_ids) > 0: - # Add each data source ID as a separate query parameter - for ds_id in data_source_ids: - if ds_id: # Skip None or empty values - params.append(("DataSourceIds", ds_id)) + if data_source_names and len(data_source_names) > 0: + # Add each data source name as a separate query parameter + for ds_name in data_source_names: + if ds_name: # Skip None or empty values + params.append(("Names", ds_name)) else: await ctx.info("Using API key's default data source (if available)") @@ -160,5 +164,5 @@ async def codebase_search( except (httpx.HTTPStatusError, Exception) as e: error_msg = await handle_api_error(ctx, e, "code search") if isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 404: - error_msg = f"Error: Not found (404): One or more data sources could not be found. Check your data_source_ids." + error_msg = f"Error: Not found (404): One or more data sources could not be found. Check your data_sources." 
return f"{error_msg}" \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py index ffdc96f..90acfc8 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,11 +1,15 @@ """Utility functions for CodeAlive MCP server.""" from .response_transformer import transform_search_response_to_xml -from .errors import handle_api_error, format_data_source_ids, normalize_data_source_ids +from .errors import ( + handle_api_error, + format_data_source_names, + normalize_data_source_names, +) __all__ = [ 'transform_search_response_to_xml', 'handle_api_error', - 'format_data_source_ids', - 'normalize_data_source_ids' -] \ No newline at end of file + 'format_data_source_names', + 'normalize_data_source_names', +] diff --git a/src/utils/errors.py b/src/utils/errors.py index deb074f..2769ea5 100644 --- a/src/utils/errors.py +++ b/src/utils/errors.py @@ -51,82 +51,63 @@ async def handle_api_error( return f"Error: {error_msg}. Please check your input parameters and try again." -def normalize_data_source_ids(data_sources) -> list: - """ - Normalize data source IDs from various Claude Desktop serialization formats. 
- - Handles: - - Proper arrays: ["id1", "id2"] - - JSON-encoded strings: "[\"id1\", \"id2\"]" - - Plain strings: "id1" - - None/empty values - - Args: - data_sources: Data sources in any format from Claude Desktop - - Returns: - List of string IDs: ["id1", "id2"] - """ +def normalize_data_source_names(data_sources) -> list: + """Normalize data source names from various serialization formats.""" import json if not data_sources: return [] - # Handle string inputs (Claude Desktop serialization issue) if isinstance(data_sources, str): - # Handle JSON-encoded string - if data_sources.startswith('['): + stripped = data_sources.strip() + if stripped.startswith('['): try: - data_sources = json.loads(data_sources) + data_sources = json.loads(stripped) except json.JSONDecodeError: - # If parsing fails, treat as single ID return [data_sources] else: - # Single ID as string return [data_sources] - # Handle non-list types if not isinstance(data_sources, list): return [str(data_sources)] - # Already a list - extract string IDs result = [] for ds in data_sources: if isinstance(ds, str): result.append(ds) - elif isinstance(ds, dict) and ds.get("id"): - result.append(ds["id"]) + elif isinstance(ds, dict): + if ds.get("name"): + result.append(ds["name"]) + elif ds.get("id"): + # Backward compatibility with legacy ID payloads + result.append(ds["id"]) return result -def format_data_source_ids(data_sources: Optional[list]) -> list: - """ - Convert various data source formats to the API's expected format. 
- - Handles: - - Simple string IDs: ["id1", "id2"] - - Dict format: [{"id": "id1"}, {"type": "repository", "id": "id2"}] - - Mixed formats - - None/empty values - - Args: - data_sources: List of data sources in various formats - - Returns: - List of dicts with 'id' field: [{"id": "id1"}, {"id": "id2"}] - """ +def format_data_source_names(data_sources: Optional[list]) -> list: + """Convert various data source inputs to a simple list of data source names.""" if not data_sources: return [] - formatted = [] + formatted: list[str] = [] + for ds in data_sources: - if isinstance(ds, str) and ds: - # Simple string ID - formatted.append({"id": ds}) - elif isinstance(ds, dict) and ds.get("id"): - # Already has id field - extract just the id - formatted.append({"id": ds["id"]}) - # Skip None/empty values - - return formatted \ No newline at end of file + if isinstance(ds, str): + name = ds.strip() + if name: + formatted.append(name) + elif isinstance(ds, dict): + name = ds.get("name") or ds.get("id") + if isinstance(name, str): + name = name.strip() + if name: + formatted.append(name) + elif name is not None: + formatted.append(str(name)) + elif ds is not None: + # Fallback: cast other primitive types to string + formatted.append(str(ds)) + + return formatted +