From 2919fbb622b1b88bf04610e37e17772fc47d7141 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 26 Nov 2025 17:05:24 +0000 Subject: [PATCH 01/20] Moving VS to tools_enabled --- src/client/utils/st_common.py | 92 ++++++------- src/common/schema.py | 14 +- src/server/agents/chatbot.py | 6 +- src/server/api/utils/chat.py | 3 +- src/server/api/v1/testbed.py | 2 +- .../server/unit/api/utils/test_utils_chat.py | 45 ++++--- .../api/utils/test_utils_databases_crud.py | 5 - .../utils/test_utils_databases_functions.py | 12 -- .../server/unit/api/utils/test_utils_embed.py | 53 ++++---- .../unit/api/utils/test_utils_models.py | 114 +++++++++-------- tests/server/unit/api/utils/test_utils_oci.py | 121 ++++++++++-------- .../unit/api/utils/test_utils_oci_refresh.py | 25 ++-- .../unit/api/utils/test_utils_testbed.py | 17 ++- 13 files changed, 267 insertions(+), 242 deletions(-) diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index 466ced9d..a96d78b5 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -239,70 +239,58 @@ def ll_sidebar() -> None: # Tools Options ##################################################### def tools_sidebar() -> None: - """Tools Sidebar Settings, conditional if all sorts of bs setup""" + """Tools Sidebar Settings""" + + # Setup Tool Box + state.tool_box = { + "LLM Only": {"description": "Do not use tools", "enabled": True}, + "Vector Search": {"description": "Use AI with Unstructured Data", "enabled": True}, + "NL2SQL": {"description": "Use AI with Structured Data", "enabled": True}, + } def _update_set_tool(): """Update user settings as to which tool is being used""" - state.client_settings["vector_search"]["enabled"] = state.selected_tool == "Vector Search" - - disable_vector_search = not is_db_configured() - - if disable_vector_search: - logger.debug("Vector Search Disabled (Database not configured)") - st.warning("Database is not configured. Disabling Vector Search tools.", icon="⚠️") - state.client_settings["vector_search"]["enabled"] = False + state.client_settings["tools_enabled"] = [state.selected_tool] + + def _disable_tool(tool: str, reason: str = None) -> None: + """Disable a tool in the tool box""" + if reason: + logger.debug("%s Disabled (%s)", tool, reason) + st.warning(f"{reason}. Disabling {tool}.", icon="⚠️") + state.tool_box[tool]["enabled"] = False + + if not is_db_configured(): + logger.debug("Vector Search/NL2SQL Disabled (Database not configured)") + st.warning("Database is not configured. Disabling Vector Search and NL2SQL tools.", icon="⚠️") + _disable_tool("Vector Search") + _disable_tool("NL2SQL") else: - # Client Settings + # Check to enable Vector Store + embed_models_enabled = enabled_models_lookup("embed") db_alias = state.client_settings.get("database", {}).get("alias") - - # Lookups database_lookup = state_configs_lookup("database_configs", "name") - - tools = [ - ("LLM Only", "Do not use tools", False), - ("Vector Search", "Use AI with Unstructured Data", disable_vector_search), - ] - - # Vector Search Requirements - embed_models_enabled = enabled_models_lookup("embed") - - def _disable_vector_search(reason): - """Disable Vector Store""" - state.client_settings["vector_search"]["enabled"] = False - logger.debug("Vector Search Disabled (%s)", reason) - st.warning(f"{reason}. 
Disabling Vector Search.", icon="⚠️") - tools[:] = [t for t in tools if t[0] != "Vector Search"] - if not embed_models_enabled: - _disable_vector_search("No embedding models are configured and/or enabled.") + _disable_tool("Vector Search", "No embedding models are configured and/or enabled.") elif not database_lookup[db_alias].get("vector_stores"): - _disable_vector_search("Database has no vector stores") + _disable_tool("Vector Search", "Database has no vector stores.") else: # Check if any vector stores use an enabled embedding model vector_stores = database_lookup[db_alias].get("vector_stores", []) usable_vector_stores = [vs for vs in vector_stores if vs.get("model") in embed_models_enabled] if not usable_vector_stores: - _disable_vector_search("No vector stores match the enabled embedding models") - - tool_box = [name for name, _, disabled in tools if not disabled] - if len(tool_box) > 1: - st.sidebar.subheader("Toolkit", divider="red") - tool_index = next( - ( - i - for i, t in enumerate(tools) - if (t[0] == "Vector Search" and state.client_settings["vector_search"]["enabled"]) - ), - 0, - ) - st.sidebar.selectbox( - "Tool Selection", - tool_box, - index=tool_index, - label_visibility="collapsed", - on_change=_update_set_tool, - key="selected_tool", - ) + _disable_tool("Vector Search", "No vector stores match the enabled embedding models") + + tool_box = [key for key, val in state.tool_box.items() if val["enabled"]] + current_tool = state.client_settings["tools_enabled"][0] + tool_index = tool_box.index(current_tool) if current_tool in tool_box else 0 + st.sidebar.selectbox( + "Tool Selection", + tool_box, + index=tool_index, + label_visibility="collapsed", + on_change=_update_set_tool, + key="selected_tool", + ) ##################################################### @@ -477,7 +465,7 @@ def reset() -> None: def vector_search_sidebar() -> None: """Vector Search Sidebar Settings, conditional if Database/Embeddings are configured""" - if state.client_settings["vector_search"]["enabled"]: + if "Vector Search" in state.client_settings["tools_enabled"]: st.sidebar.subheader("Vector Search", divider="red") # Search Type Selection diff --git a/src/common/schema.py b/src/common/schema.py index 4895869d..057487a0 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -5,7 +5,7 @@ # spell-checker:ignore hnsw ocid aioptimizer explainsql genai mult ollama showsql rerank import time -from typing import Optional, Literal, Any +from typing import Optional, Literal, List, Any from pydantic import BaseModel, Field, PrivateAttr, ConfigDict from langchain_core.messages import ChatMessage @@ -31,6 +31,7 @@ class DatabaseVectorStorage(BaseModel): json_schema_extra={"readOnly": True}, ) alias: Optional[str] = Field(default=None, description="Identifiable Alias") + description: Optional[str] = Field(default=None, description="Human-readable description of table contents") model: Optional[str] = Field(default=None, description="Embedding Model") chunk_size: Optional[int] = Field(default=0, description="Chunk Size") chunk_overlap: Optional[int] = Field(default=0, description="Chunk Overlap") @@ -202,10 +203,11 @@ class LargeLanguageSettings(LanguageModelParameters): class VectorSearchSettings(DatabaseVectorStorage): - """Store vector_search Settings incl VectorStorage""" + """Store vector_search Settings""" - enabled: bool = Field(default=False, description="vector_search Enabled") - grading: bool = Field(default=True, description="Grade vector_search Results") + discovery: bool = Field(default=True, 
description="Auto-discover Vector Stores") + rephrase: bool = Field(default=True, description="Rephrase User Prompt") + grade: bool = Field(default=True, description="Grade Vector Search Results") search_type: Literal["Similarity", "Similarity Score Threshold", "Maximal Marginal Relevance"] = Field( default="Similarity", description="Search Type" ) @@ -252,6 +254,10 @@ class Settings(BaseModel): ) oci: Optional[OciSettings] = Field(default_factory=OciSettings, description="OCI Settings") database: Optional[DatabaseSettings] = Field(default_factory=DatabaseSettings, description="Database Settings") + tools_enabled: List[str] = Field( + default_factory=lambda: ["LLM Only"], + description="List of enabled MCP tools for this client", + ) vector_search: Optional[VectorSearchSettings] = Field( default_factory=VectorSearchSettings, description="Vector Search Settings" ) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index 8414927f..16ff3ff5 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -76,7 +76,7 @@ def clean_messages(state: OptimizerState, config: RunnableConfig) -> list: def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "stream_completion"]: """Conditional edge to determine if using Vector Search or not""" - enabled = config["metadata"]["vector_search"].enabled + enabled = "Vector Search" in config.get("metadata", {}).get("tools_enabled", []) if enabled: logger.info("Invoking Chatbot with Vector Search: %s", enabled) return "vs_retrieve" @@ -138,7 +138,7 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt # Initialise documents as relevant relevant = "yes" documents_dict = document_formatter(state["documents"]) - if config["metadata"]["vector_search"].grading and state.get("documents"): + if config["metadata"]["vector_search"].grade and state.get("documents"): grade_prompt_msg = default_prompts.get_prompt_with_override("optimizer_vs-grade") grade_template_text = grade_prompt_msg.content.text @@ -240,7 +240,7 @@ async def vs_retrieve(state: OptimizerState, config: RunnableConfig) -> Optimize def _build_system_prompt(state: OptimizerState, config: RunnableConfig) -> SystemMessage: """Build the system prompt based on vector search configuration.""" - vector_search_enabled = config["metadata"]["vector_search"].enabled + vector_search_enabled = "Vector Search" in config.get("metadata", {}).get("tools_enabled", []) if vector_search_enabled: sys_prompt_msg = default_prompts.get_prompt_with_override("optimizer_vs-no-tools-default") diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 25d2615b..6976a101 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -64,6 +64,7 @@ async def completion_generator( "config": RunnableConfig( configurable={"thread_id": client, "ll_config": ll_config}, metadata={ + "tools_enabled": client_settings.tools_enabled, "use_history": client_settings.ll_model.chat_history, "vector_search": client_settings.vector_search, "streaming": call == "streams", @@ -72,7 +73,7 @@ async def completion_generator( } # Add DB Conn to KWargs when needed - if client_settings.vector_search.enabled: + if "Vector Search" in client_settings.tools_enabled: db_conn = utils_databases.get_client_database(client, False).connection kwargs["config"]["configurable"]["db_conn"] = db_conn kwargs["config"]["configurable"]["embed_client"] = utils_models.get_client_embed( diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 
93c84ad6..6e09c235 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -254,7 +254,7 @@ def get_answer(question: str): # Disable History client_settings.ll_model.chat_history = False # Disable Grade vector_search - client_settings.vector_search.grading = False + client_settings.vector_search.grade = False db_conn = utils_databases.get_client_database(client).connection testset = utils_testbed.get_testset_qa(db_conn=db_conn, tid=tid.upper()) diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py index 12e8f662..2b620a12 100644 --- a/tests/server/unit/api/utils/test_utils_chat.py +++ b/tests/server/unit/api/utils/test_utils_chat.py @@ -23,14 +23,23 @@ class TestChatUtils: """Test chat utility functions""" - def __init__(self): - """Setup test data""" - self.sample_message = ChatMessage(role="user", content="Hello, how are you?") - self.sample_request = ChatRequest(messages=[self.sample_message], model="openai/gpt-4") - self.sample_client_settings = Settings( + @pytest.fixture + def sample_message(self): + """Sample chat message fixture""" + return ChatMessage(role="user", content="Hello, how are you?") + + @pytest.fixture + def sample_request(self, sample_message): + """Sample chat request fixture""" + return ChatRequest(messages=[sample_message], model="openai/gpt-4") + + @pytest.fixture + def sample_client_settings(self): + """Sample client settings fixture""" + return Settings( client="test_client", ll_model=LargeLanguageSettings(model="openai/gpt-4", chat_history=True, temperature=0.7, max_tokens=4096), - vector_search=VectorSearchSettings(enabled=False), + vector_search=VectorSearchSettings(), oci=OciSettings(auth_profile="DEFAULT"), ) @@ -40,11 +49,11 @@ def __init__(self): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_success( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_request, sample_client_settings ): """Test successful completion generation""" # Setup mocks - mock_get_client.return_value = self.sample_client_settings + mock_get_client.return_value = sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} @@ -58,7 +67,7 @@ async def mock_generator(): # Test the function results = [] - async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + async for result in chat.completion_generator("test_client", sample_request, "completions"): results.append(result) # Verify results - for "completions" mode, we get stream chunks + final completion @@ -75,11 +84,11 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_streaming( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_request, sample_client_settings ): """Test streaming completion generation""" # Setup mocks - mock_get_client.return_value = self.sample_client_settings + mock_get_client.return_value = sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} @@ -93,7 +102,7 @@ async def mock_generator(): # Test the function results = [] - async for result in 
chat.completion_generator("test_client", self.sample_request, "streams"): + async for result in chat.completion_generator("test_client", sample_request, "streams"): results.append(result) # Verify results - should include encoded stream chunks and finish marker @@ -117,10 +126,12 @@ async def test_completion_generator_with_vector_search( mock_get_litellm_config, mock_get_oci, mock_get_client, + sample_request, + sample_client_settings, ): """Test completion generation with vector search enabled""" # Setup settings with vector search enabled - vector_search_settings = self.sample_client_settings.model_copy() + vector_search_settings = sample_client_settings.model_copy() vector_search_settings.vector_search.enabled = True # Setup mocks @@ -141,7 +152,7 @@ async def mock_generator(): # Test the function results = [] - async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + async for result in chat.completion_generator("test_client", sample_request, "completions"): results.append(result) # Verify vector search setup @@ -155,14 +166,14 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_no_model_specified( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_message, sample_client_settings ): """Test completion generation when no model is specified in request""" # Create request without model - request_no_model = ChatRequest(messages=[self.sample_message], model=None) + request_no_model = ChatRequest(messages=[sample_message], model=None) # Setup mocks - mock_get_client.return_value = self.sample_client_settings + mock_get_client.return_value = sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} diff --git a/tests/server/unit/api/utils/test_utils_databases_crud.py b/tests/server/unit/api/utils/test_utils_databases_crud.py index f50d0a7d..f0c9af60 100644 --- a/tests/server/unit/api/utils/test_utils_databases_crud.py +++ b/tests/server/unit/api/utils/test_utils_databases_crud.py @@ -17,11 +17,6 @@ class TestDatabases: """Test databases module functionality""" - def __init__(self): - """Initialize test data""" - self.sample_database = None - self.sample_database_2 = None - def setup_method(self): """Setup test data before each test""" self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") diff --git a/tests/server/unit/api/utils/test_utils_databases_functions.py b/tests/server/unit/api/utils/test_utils_databases_functions.py index e79f1d42..eb2a6bf3 100644 --- a/tests/server/unit/api/utils/test_utils_databases_functions.py +++ b/tests/server/unit/api/utils/test_utils_databases_functions.py @@ -20,10 +20,6 @@ class TestDatabaseUtilsPrivateFunctions: """Test private utility functions""" - def __init__(self): - """Initialize test data""" - self.sample_database = None - def setup_method(self): """Setup test data""" self.sample_database = Database( @@ -172,10 +168,6 @@ def test_get_vs_malformed_json(self, mock_execute_sql): class TestDatabaseUtilsPublicFunctions: """Test public utility functions - connection and execution""" - def __init__(self): - """Initialize test data""" - self.sample_database = None - def setup_method(self): """Setup test data""" self.sample_database = Database( @@ -445,10 +437,6 @@ def 
test_drop_vs_calls_langchain(self, mock_drop_table): class TestDatabaseUtilsQueryFunctions: """Test public utility functions - get and client database functions""" - def __init__(self): - """Initialize test data""" - self.sample_database = None - def setup_method(self): """Setup test data""" self.sample_database = Database( diff --git a/tests/server/unit/api/utils/test_utils_embed.py b/tests/server/unit/api/utils/test_utils_embed.py index 161aedc4..13a63538 100644 --- a/tests/server/unit/api/utils/test_utils_embed.py +++ b/tests/server/unit/api/utils/test_utils_embed.py @@ -9,6 +9,7 @@ from pathlib import Path from unittest.mock import patch, mock_open, MagicMock +import pytest from langchain.docstore.document import Document as LangchainDocument from server.api.utils import embed @@ -18,12 +19,17 @@ class TestEmbedUtils: """Test embed utility functions""" - def __init__(self): - """Setup test data""" - self.sample_document = LangchainDocument( + @pytest.fixture + def sample_document(self): + """Sample document fixture""" + return LangchainDocument( page_content="This is a test document content.", metadata={"source": "/path/to/test_file.txt", "page": 1} ) - self.sample_split_doc = LangchainDocument( + + @pytest.fixture + def sample_split_doc(self): + """Sample split document fixture""" + return LangchainDocument( page_content="This is a chunk of content.", metadata={"source": "/path/to/test_file.txt", "start_index": 0} ) @@ -54,12 +60,12 @@ def test_get_temp_directory_tmp_fallback(self, mock_mkdir, mock_exists): @patch("builtins.open", new_callable=mock_open) @patch("os.path.getsize") @patch("json.dumps") - def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file): + def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): """Test document to JSON conversion with default output directory""" mock_json_dumps.return_value = '{"test": "data"}' mock_getsize.return_value = 100 - result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/tmp") + result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/tmp") mock_file.assert_called_once() mock_json_dumps.assert_called_once() @@ -69,12 +75,12 @@ def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_fi @patch("builtins.open", new_callable=mock_open) @patch("os.path.getsize") @patch("json.dumps") - def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file): + def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): """Test document to JSON conversion with custom output directory""" mock_json_dumps.return_value = '{"test": "data"}' mock_getsize.return_value = 100 - result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/custom/output") + result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/custom/output") mock_file.assert_called_once() mock_json_dumps.assert_called_once() @@ -90,9 +96,10 @@ def test_logger_exists(self): class TestGetVectorStoreFiles: """Test get_vector_store_files() function""" - def __init__(self): - """Setup test data""" - self.sample_db = Database( + @pytest.fixture + def sample_db(self): + """Sample database fixture""" + return Database( name="TEST_DB", user="test_user", password="", @@ -101,7 +108,7 @@ def __init__(self): @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_metadata(self, mock_disconnect, 
mock_connect): + def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connect, sample_db): """Test retrieving file list with complete metadata""" # Mock database connection and cursor mock_conn = MagicMock() @@ -132,7 +139,7 @@ def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connec ] # Execute - result = embed.get_vector_store_files(self.sample_db, "TEST_VS") + result = embed.get_vector_store_files(sample_db, "TEST_VS") # Verify assert result["vector_store"] == "TEST_VS" @@ -152,7 +159,7 @@ def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connec @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_connect): + def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_connect, sample_db): """Test handling of Decimal size from Oracle NUMBER type""" # Mock database connection mock_conn = MagicMock() @@ -171,7 +178,7 @@ def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_c ] # Execute - result = embed.get_vector_store_files(self.sample_db, "TEST_VS") + result = embed.get_vector_store_files(sample_db, "TEST_VS") # Verify Decimal was converted to int assert result["files"][0]["size"] == 1024000 @@ -179,7 +186,7 @@ def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_c @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect): + def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect, sample_db): """Test retrieving files with old metadata format (source field)""" # Mock database connection mock_conn = MagicMock() @@ -194,7 +201,7 @@ def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect) ] # Execute - result = embed.get_vector_store_files(self.sample_db, "TEST_VS") + result = embed.get_vector_store_files(sample_db, "TEST_VS") # Verify fallback to source field worked assert result["total_files"] == 1 @@ -203,7 +210,7 @@ def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect) @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, mock_connect): + def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, mock_connect, sample_db): """Test detection of orphaned chunks without valid filename""" # Mock database connection mock_conn = MagicMock() @@ -220,7 +227,7 @@ def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, moc ] # Execute - result = embed.get_vector_store_files(self.sample_db, "TEST_VS") + result = embed.get_vector_store_files(sample_db, "TEST_VS") # Verify assert result["total_files"] == 1 @@ -230,7 +237,7 @@ def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, moc @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect): + def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect, sample_db): """Test retrieving from empty vector store""" # Mock database connection mock_conn = MagicMock() @@ -242,7 +249,7 @@ def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect mock_cursor.fetchall.return_value 
= [] # Execute - result = embed.get_vector_store_files(self.sample_db, "EMPTY_VS") + result = embed.get_vector_store_files(sample_db, "EMPTY_VS") # Verify assert result["vector_store"] == "EMPTY_VS" @@ -253,7 +260,7 @@ def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_connect): + def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_connect, sample_db): """Test that files are sorted alphabetically by filename""" # Mock database connection mock_conn = MagicMock() @@ -269,7 +276,7 @@ def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_c ] # Execute - result = embed.get_vector_store_files(self.sample_db, "TEST_VS") + result = embed.get_vector_store_files(sample_db, "TEST_VS") # Verify sorted order filenames = [f["filename"] for f in result["files"]] diff --git a/tests/server/unit/api/utils/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py index ef1a2f3c..70491d8d 100644 --- a/tests/server/unit/api/utils/test_utils_models.py +++ b/tests/server/unit/api/utils/test_utils_models.py @@ -52,58 +52,63 @@ def test_unknown_model_error(self): class TestModelsCRUD: """Test models module functionality""" - def __init__(self): - """Setup test data for all tests""" - self.sample_model = Model( + @pytest.fixture + def sample_model(self): + """Sample model fixture""" + return Model( id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" ) - self.disabled_model = Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) + + @pytest.fixture + def disabled_model(self): + """Disabled model fixture""" + return Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_all_models(self, mock_model_objects): + def test_get_model_all_models(self, mock_model_objects, sample_model, disabled_model): """Test getting all models without filters""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model, self.disabled_model])) + mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model, disabled_model])) mock_model_objects.__len__ = MagicMock(return_value=2) result = models.get() - assert result == [self.sample_model, self.disabled_model] + assert result == [sample_model, disabled_model] @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_found(self, mock_model_objects): + def test_get_model_by_id_found(self, mock_model_objects, sample_model): """Test getting model by ID when it exists""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) + mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) mock_model_objects.__len__ = MagicMock(return_value=1) (result,) = models.get(model_id="test-model") - assert result == self.sample_model + assert result == sample_model @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_not_found(self, mock_model_objects): + def test_get_model_by_id_not_found(self, mock_model_objects, sample_model): """Test getting model by ID when it doesn't exist""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) + mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) mock_model_objects.__len__ = 
MagicMock(return_value=1) with pytest.raises(UnknownModelError, match="nonexistent not found"): models.get(model_id="nonexistent") @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_provider(self, mock_model_objects): + def test_get_model_by_provider(self, mock_model_objects, sample_model, disabled_model): """Test filtering models by provider""" - all_models = [self.sample_model, self.disabled_model] + all_models = [sample_model, disabled_model] mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) (result,) = models.get(model_provider="openai") # Since only one model matches provider="openai", it will return a list of single model - assert result == self.sample_model + assert result == sample_model @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_type(self, mock_model_objects): + def test_get_model_by_type(self, mock_model_objects, sample_model, disabled_model): """Test filtering models by type""" - all_models = [self.sample_model, self.disabled_model] + all_models = [sample_model, disabled_model] mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) @@ -112,34 +117,34 @@ def test_get_model_by_type(self, mock_model_objects): assert result == all_models @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_exclude_disabled(self, mock_model_objects): + def test_get_model_exclude_disabled(self, mock_model_objects, sample_model, disabled_model): """Test excluding disabled models""" - all_models = [self.sample_model, self.disabled_model] + all_models = [sample_model, disabled_model] mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) (result,) = models.get(include_disabled=False) - assert result == self.sample_model + assert result == sample_model @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") - def test_create_model_success(self, mock_url_check): + def test_create_model_success(self, mock_url_check, sample_model): """Test successful model creation""" mock_url_check.return_value = (True, None) - result = models.create(self.sample_model) + result = models.create(sample_model) - assert result == self.sample_model + assert result == sample_model assert result in models.MODEL_OBJECTS @patch("server.api.utils.models.MODEL_OBJECTS") @patch("server.api.utils.models.get") - def test_create_model_already_exists(self, mock_get_model, _mock_model_objects): + def test_create_model_already_exists(self, mock_get_model, _mock_model_objects, sample_model): """Test creating model that already exists""" - mock_get_model.return_value = self.sample_model # Model already exists + mock_get_model.return_value = sample_model # Model already exists with pytest.raises(ExistsModelError, match="Model: openai/test-model already exists"): - models.create(self.sample_model) + models.create(sample_model) @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") @@ -161,11 +166,11 @@ def test_create_model_unreachable_url(self, mock_url_check): assert result.enabled is False @patch("server.api.utils.models.MODEL_OBJECTS", []) - def test_create_model_skip_url_check(self): + def test_create_model_skip_url_check(self, sample_model): """Test creating model without URL check""" - result = 
models.create(self.sample_model, check_url=False) + result = models.create(sample_model, check_url=False) - assert result == self.sample_model + assert result == sample_model assert result in models.MODEL_OBJECTS @patch("server.api.utils.models.MODEL_OBJECTS") @@ -195,19 +200,24 @@ def test_logger_exists(self): class TestModelsUtils: """Test models utility functions""" - def __init__(self): - """Setup test data""" - self.sample_model = Model( + @pytest.fixture + def sample_model(self): + """Sample model fixture""" + return Model( id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" ) - self.sample_oci_config = get_sample_oci_config() + + @pytest.fixture + def sample_oci_config(self): + """Sample OCI config fixture""" + return get_sample_oci_config() @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") - def test_update_success(self, mock_url_check): + def test_update_success(self, mock_url_check, sample_model): """Test successful model update""" # First create the model - models.MODEL_OBJECTS.append(self.sample_model) + models.MODEL_OBJECTS.append(sample_model) mock_url_check.return_value = (True, None) update_payload = Model( @@ -262,10 +272,10 @@ def test_update_embedding_model_max_chunk_size(self, mock_url_check): @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") - def test_update_multiple_fields(self, mock_url_check): + def test_update_multiple_fields(self, mock_url_check, sample_model): """Test updating multiple fields at once""" # Create a model - models.MODEL_OBJECTS.append(self.sample_model) + models.MODEL_OBJECTS.append(sample_model) mock_url_check.return_value = (True, None) # Update multiple fields @@ -287,12 +297,12 @@ def test_update_multiple_fields(self, mock_url_check): assert result.max_tokens == 2048 @patch("server.api.utils.models.get") - def test_get_full_config_success(self, mock_get_model): + def test_get_full_config_success(self, mock_get_model, sample_model, sample_oci_config): """Test successful full config retrieval""" - mock_get_model.return_value = [self.sample_model] + mock_get_model.return_value = [sample_model] model_config = {"model": "openai/gpt-4", "temperature": 0.8} - full_config, provider = models._get_full_config(model_config, self.sample_oci_config) + full_config, provider = models._get_full_config(model_config, sample_oci_config) assert provider == "openai" assert full_config["temperature"] == 0.8 @@ -300,17 +310,17 @@ def test_get_full_config_success(self, mock_get_model): mock_get_model.assert_called_once_with(model_provider="openai", model_id="gpt-4", include_disabled=False) @patch("server.api.utils.models.get") - def test_get_full_config_unknown_model(self, mock_get_model): + def test_get_full_config_unknown_model(self, mock_get_model, sample_oci_config): """Test full config retrieval with unknown model""" mock_get_model.side_effect = UnknownModelError("Model not found") model_config = {"model": "unknown/model"} with pytest.raises(UnknownModelError): - models._get_full_config(model_config, self.sample_oci_config) + models._get_full_config(model_config, sample_oci_config) @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config): + def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config, sample_oci_config): """Test basic LiteLLM config generation""" 
mock_get_full_config.return_value = ( {"temperature": 0.7, "max_tokens": 4096, "api_base": "https://api.openai.com"}, @@ -319,7 +329,7 @@ def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config): mock_get_params.return_value = ["temperature", "max_tokens"] model_config = {"model": "openai/gpt-4"} - result = models.get_litellm_config(model_config, self.sample_oci_config) + result = models.get_litellm_config(model_config, sample_oci_config) assert result["model"] == "openai/gpt-4" assert result["temperature"] == 0.7 @@ -328,20 +338,20 @@ def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config): @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config): + def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config, sample_oci_config): """Test LiteLLM config generation for Cohere""" mock_get_full_config.return_value = ({"api_base": "https://custom.cohere.com/v1"}, "cohere") mock_get_params.return_value = [] model_config = {"model": "cohere/command"} - result = models.get_litellm_config(model_config, self.sample_oci_config) + result = models.get_litellm_config(model_config, sample_oci_config) assert result["api_base"] == "https://api.cohere.ai/compatibility/v1" assert result["model"] == "cohere/command" @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config): + def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config, sample_oci_config): """Test LiteLLM config generation for xAI""" mock_get_full_config.return_value = ( {"temperature": 0.7, "presence_penalty": 0.1, "frequency_penalty": 0.1}, @@ -350,7 +360,7 @@ def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config): mock_get_params.return_value = ["temperature", "presence_penalty", "frequency_penalty"] model_config = {"model": "xai/grok"} - result = models.get_litellm_config(model_config, self.sample_oci_config) + result = models.get_litellm_config(model_config, sample_oci_config) assert result["temperature"] == 0.7 assert "presence_penalty" not in result @@ -358,13 +368,13 @@ def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config): @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config): + def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config, sample_oci_config): """Test LiteLLM config generation for OCI""" mock_get_full_config.return_value = ({"temperature": 0.7}, "oci") mock_get_params.return_value = ["temperature"] model_config = {"model": "oci/cohere.command"} - result = models.get_litellm_config(model_config, self.sample_oci_config) + result = models.get_litellm_config(model_config, sample_oci_config) assert result["oci_user"] == "ocid1.user.oc1..testuser" assert result["oci_fingerprint"] == "test-fingerprint" @@ -374,13 +384,13 @@ def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config): @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config): + def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config, sample_oci_config): """Test LiteLLM config generation for Giskard""" 
mock_get_full_config.return_value = ({"temperature": 0.7, "model": "test-model"}, "openai") mock_get_params.return_value = ["temperature", "model"] model_config = {"model": "openai/gpt-4"} - result = models.get_litellm_config(model_config, self.sample_oci_config, giskard=True) + result = models.get_litellm_config(model_config, sample_oci_config, giskard=True) assert "model" not in result assert "temperature" not in result diff --git a/tests/server/unit/api/utils/test_utils_oci.py b/tests/server/unit/api/utils/test_utils_oci.py index 02c5c217..1e8d4819 100644 --- a/tests/server/unit/api/utils/test_utils_oci.py +++ b/tests/server/unit/api/utils/test_utils_oci.py @@ -30,15 +30,24 @@ def test_oci_exception_initialization(self): class TestOciGet: """Test OCI get() function""" - def __init__(self): - """Setup test data for all tests""" - self.sample_oci_default = OracleCloudSettings( + @pytest.fixture + def sample_oci_default(self): + """Sample OCI config with DEFAULT profile""" + return OracleCloudSettings( auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" ) - self.sample_oci_custom = OracleCloudSettings( + + @pytest.fixture + def sample_oci_custom(self): + """Sample OCI config with CUSTOM profile""" + return OracleCloudSettings( auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" ) - self.sample_client_settings = Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) + + @pytest.fixture + def sample_client_settings(self): + """Sample client settings fixture""" + return Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) @patch("server.bootstrap.bootstrap.OCI_OBJECTS", []) def test_get_no_objects_configured(self): @@ -47,9 +56,9 @@ def test_get_no_objects_configured(self): oci_utils.get() @patch("server.bootstrap.bootstrap.OCI_OBJECTS", new_callable=list) - def test_get_all(self, mock_oci_objects): + def test_get_all(self, mock_oci_objects, sample_oci_default, sample_oci_custom): """Test getting all OCI settings when no filters are provided""" - all_oci = [self.sample_oci_default, self.sample_oci_custom] + all_oci = [sample_oci_default, sample_oci_custom] mock_oci_objects.extend(all_oci) result = oci_utils.get() @@ -57,23 +66,23 @@ def test_get_all(self, mock_oci_objects): assert result == all_oci @patch("server.bootstrap.bootstrap.OCI_OBJECTS") - def test_get_by_auth_profile_found(self, mock_oci_objects): + def test_get_by_auth_profile_found(self, mock_oci_objects, sample_oci_default, sample_oci_custom): """Test getting OCI settings by auth_profile when it exists""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) + mock_oci_objects.__iter__ = MagicMock(return_value=iter([sample_oci_default, sample_oci_custom])) result = oci_utils.get(auth_profile="CUSTOM") - assert result == self.sample_oci_custom + assert result == sample_oci_custom @patch("server.bootstrap.bootstrap.OCI_OBJECTS") - def test_get_by_auth_profile_not_found(self, mock_oci_objects): + def test_get_by_auth_profile_not_found(self, mock_oci_objects, sample_oci_default): """Test getting OCI settings by auth_profile when it doesn't exist""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) + mock_oci_objects.__iter__ = MagicMock(return_value=iter([sample_oci_default])) with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): oci_utils.get(auth_profile="NONEXISTENT") - def test_get_by_client_with_oci_settings(self): + def 
test_get_by_client_with_oci_settings(self, sample_client_settings, sample_oci_default, sample_oci_custom): """Test getting OCI settings by client when client has OCI settings""" from server.bootstrap import bootstrap @@ -83,18 +92,18 @@ def test_get_by_client_with_oci_settings(self): try: # Replace with test data - bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] - bootstrap.OCI_OBJECTS = [self.sample_oci_default, self.sample_oci_custom] + bootstrap.SETTINGS_OBJECTS = [sample_client_settings] + bootstrap.OCI_OBJECTS = [sample_oci_default, sample_oci_custom] result = oci_utils.get(client="test_client") - assert result == self.sample_oci_custom + assert result == sample_oci_custom finally: # Restore originals bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - def test_get_by_client_without_oci_settings(self): + def test_get_by_client_without_oci_settings(self, sample_oci_default): """Test getting OCI settings by client when client has no OCI settings""" from server.bootstrap import bootstrap @@ -107,11 +116,11 @@ def test_get_by_client_without_oci_settings(self): try: # Replace with test data bootstrap.SETTINGS_OBJECTS = [client_settings_no_oci] - bootstrap.OCI_OBJECTS = [self.sample_oci_default] + bootstrap.OCI_OBJECTS = [sample_oci_default] result = oci_utils.get(client="test_client") - assert result == self.sample_oci_default + assert result == sample_oci_default finally: # Restore originals bootstrap.SETTINGS_OBJECTS = orig_settings @@ -126,7 +135,7 @@ def test_get_by_client_not_found(self, mock_settings_objects, _mock_oci_objects) with pytest.raises(ValueError, match="client test_client not found"): oci_utils.get(client="test_client") - def test_get_by_client_no_matching_profile(self): + def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_oci_default): """Test getting OCI settings by client when no matching profile exists""" from server.bootstrap import bootstrap @@ -136,8 +145,8 @@ def test_get_by_client_no_matching_profile(self): try: # Replace with test data - bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] - bootstrap.OCI_OBJECTS = [self.sample_oci_default] # Only DEFAULT profile + bootstrap.SETTINGS_OBJECTS = [sample_client_settings] + bootstrap.OCI_OBJECTS = [sample_oci_default] # Only DEFAULT profile expected_error = "No settings found for client 'test_client' with auth_profile 'CUSTOM'" with pytest.raises(ValueError, match=expected_error): @@ -202,9 +211,10 @@ def test_get_signer_security_token(self): class TestInitClient: """Test init_client() function""" - def __init__(self): - """Setup test data""" - self.api_key_config = OracleCloudSettings( + @pytest.fixture + def api_key_config(self): + """API key configuration fixture""" + return OracleCloudSettings( auth_profile="DEFAULT", authentication="api_key", region="us-ashburn-1", @@ -216,22 +226,22 @@ def __init__(self): @patch("oci.object_storage.ObjectStorageClient") @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_api_key(self, mock_get_signer, mock_client_class): + def test_init_client_api_key(self, mock_get_signer, mock_client_class, api_key_config): """Test init_client with API key authentication""" mock_client = MagicMock() mock_client_class.return_value = mock_client - result = oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + result = oci_utils.init_client(oci.object_storage.ObjectStorageClient, api_key_config) assert result == mock_client - 
mock_get_signer.assert_called_once_with(self.api_key_config) + mock_get_signer.assert_called_once_with(api_key_config) mock_client_class.assert_called_once() @patch("oci.generative_ai_inference.GenerativeAiInferenceClient") @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_genai_with_endpoint(self, _mock_get_signer, mock_client_class): + def test_init_client_genai_with_endpoint(self, _mock_get_signer, mock_client_class, api_key_config): """Test init_client for GenAI sets correct service endpoint""" - genai_config = self.api_key_config.model_copy() + genai_config = api_key_config.model_copy() genai_config.genai_compartment_id = "ocid1.compartment.oc1..test" genai_config.genai_region = "us-chicago-1" @@ -341,12 +351,12 @@ def test_init_client_with_security_token( @patch("oci.object_storage.ObjectStorageClient") @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_invalid_config(self, _mock_get_signer, mock_client_class): + def test_init_client_invalid_config(self, _mock_get_signer, mock_client_class, api_key_config): """Test init_client with invalid config raises OciException""" mock_client_class.side_effect = oci.exceptions.InvalidConfig("Bad config") with pytest.raises(OciException) as exc_info: - oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + oci_utils.init_client(oci.object_storage.ObjectStorageClient, api_key_config) assert exc_info.value.status_code == 400 assert "Invalid Config" in str(exc_info.value) @@ -355,61 +365,62 @@ def test_init_client_invalid_config(self, _mock_get_signer, mock_client_class): class TestOciUtils: """Test OCI utility functions""" - def __init__(self): - """Setup test data""" - self.sample_oci_config = get_sample_oci_config() + @pytest.fixture + def sample_oci_config(self): + """Sample OCI config fixture""" + return get_sample_oci_config() - def test_init_genai_client(self): + def test_init_genai_client(self, sample_oci_config): """Test GenAI client initialization""" with patch.object(oci_utils, "init_client") as mock_init_client: mock_client = MagicMock() mock_init_client.return_value = mock_client - result = oci_utils.init_genai_client(self.sample_oci_config) + result = oci_utils.init_genai_client(sample_oci_config) assert result == mock_client mock_init_client.assert_called_once_with( - oci.generative_ai_inference.GenerativeAiInferenceClient, self.sample_oci_config + oci.generative_ai_inference.GenerativeAiInferenceClient, sample_oci_config ) @patch.object(oci_utils, "init_client") - def test_get_namespace_success(self, mock_init_client): + def test_get_namespace_success(self, mock_init_client, sample_oci_config): """Test successful namespace retrieval""" mock_client = MagicMock() mock_client.get_namespace.return_value.data = "test-namespace" mock_init_client.return_value = mock_client - result = oci_utils.get_namespace(self.sample_oci_config) + result = oci_utils.get_namespace(sample_oci_config) assert result == "test-namespace" - assert self.sample_oci_config.namespace == "test-namespace" + assert sample_oci_config.namespace == "test-namespace" @patch.object(oci_utils, "init_client") - def test_get_namespace_invalid_config(self, mock_init_client): + def test_get_namespace_invalid_config(self, mock_init_client, sample_oci_config): """Test namespace retrieval with invalid config""" mock_client = MagicMock() mock_client.get_namespace.side_effect = oci.exceptions.InvalidConfig("Invalid config") mock_init_client.return_value = mock_client with 
pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 400 assert "Invalid Config" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_file_not_found(self, mock_init_client): + def test_get_namespace_file_not_found(self, mock_init_client, sample_oci_config): """Test namespace retrieval with file not found error""" mock_init_client.side_effect = FileNotFoundError("Key file not found") with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 400 assert "Invalid Key Path" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_service_error(self, mock_init_client): + def test_get_namespace_service_error(self, mock_init_client, sample_oci_config): """Test namespace retrieval with service error""" mock_client = MagicMock() mock_client.get_namespace.side_effect = oci.exceptions.ServiceError( @@ -418,51 +429,51 @@ def test_get_namespace_service_error(self, mock_init_client): mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 401 assert "AuthN Error" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_unbound_local_error(self, mock_init_client): + def test_get_namespace_unbound_local_error(self, mock_init_client, sample_oci_config): """Test namespace retrieval with unbound local error""" mock_client = MagicMock() mock_client.get_namespace.side_effect = UnboundLocalError("local variable referenced before assignment") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 500 assert "No Configuration" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_request_exception(self, mock_init_client): + def test_get_namespace_request_exception(self, mock_init_client, sample_oci_config): """Test namespace retrieval with request exception""" mock_client = MagicMock() mock_client.get_namespace.side_effect = oci.exceptions.RequestException("Connection timeout") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 503 @patch.object(oci_utils, "init_client") - def test_get_namespace_generic_exception(self, mock_init_client): + def test_get_namespace_generic_exception(self, mock_init_client, sample_oci_config): """Test namespace retrieval with generic exception""" mock_client = MagicMock() mock_client.get_namespace.side_effect = Exception("Unexpected error") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(self.sample_oci_config) + oci_utils.get_namespace(sample_oci_config) assert exc_info.value.status_code == 500 assert "Unexpected error" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_regions_success(self, mock_init_client): + def test_get_regions_success(self, mock_init_client, sample_oci_config): """Test successful regions retrieval""" 
mock_client = MagicMock() mock_region = MagicMock() @@ -473,7 +484,7 @@ def test_get_regions_success(self, mock_init_client): mock_client.list_region_subscriptions.return_value.data = [mock_region] mock_init_client.return_value = mock_client - result = oci_utils.get_regions(self.sample_oci_config) + result = oci_utils.get_regions(sample_oci_config) assert len(result) == 1 assert result[0]["is_home_region"] is True diff --git a/tests/server/unit/api/utils/test_utils_oci_refresh.py b/tests/server/unit/api/utils/test_utils_oci_refresh.py index 7857c306..72b81920 100644 --- a/tests/server/unit/api/utils/test_utils_oci_refresh.py +++ b/tests/server/unit/api/utils/test_utils_oci_refresh.py @@ -8,6 +8,8 @@ from datetime import datetime from unittest.mock import patch, MagicMock +import pytest + from server.api.utils import oci as oci_utils from common.schema import OracleCloudSettings @@ -15,9 +17,10 @@ class TestGetBucketObjectsWithMetadata: """Test get_bucket_objects_with_metadata() function""" - def __init__(self): - """Setup test data""" - self.sample_oci_config = OracleCloudSettings( + @pytest.fixture + def sample_oci_config(self): + """Sample OCI config fixture""" + return OracleCloudSettings( auth_profile="DEFAULT", namespace="test-namespace", compartment_id="ocid1.compartment.oc1..test", @@ -35,7 +38,7 @@ def create_mock_object(self, name, size, etag, time_modified, md5): return mock_obj @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_with_metadata_success(self, mock_init_client): + def test_get_bucket_objects_with_metadata_success(self, mock_init_client, sample_oci_config): """Test successful retrieval of bucket objects with metadata""" # Create mock objects time1 = datetime(2025, 11, 1, 10, 0, 0) @@ -56,7 +59,7 @@ def test_get_bucket_objects_with_metadata_success(self, mock_init_client): mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) # Verify assert len(result) == 2 @@ -78,7 +81,7 @@ def test_get_bucket_objects_with_metadata_success(self, mock_init_client): assert "etag" in call_kwargs["fields"] @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client): + def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client, sample_oci_config): """Test that unsupported file types are filtered out""" # Create mock objects with various file types mock_pdf = self.create_mock_object("doc.pdf", 1000, "etag1", datetime.now(), "md5-1") @@ -94,7 +97,7 @@ def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client): mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) # Verify only supported types are included assert len(result) == 2 @@ -105,7 +108,7 @@ def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client): assert "archive.zip" not in names @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_empty_bucket(self, mock_init_client): + def test_get_bucket_objects_empty_bucket(self, mock_init_client, sample_oci_config): """Test handling of empty bucket""" # Mock empty bucket mock_client = MagicMock() @@ -115,13 +118,13 @@ def test_get_bucket_objects_empty_bucket(self, mock_init_client): 
mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("empty-bucket", self.sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("empty-bucket", sample_oci_config) # Verify assert len(result) == 0 @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_none_time_modified(self, mock_init_client): + def test_get_bucket_objects_none_time_modified(self, mock_init_client, sample_oci_config): """Test handling of objects with None time_modified""" # Create mock object with None time_modified mock_obj = self.create_mock_object( @@ -136,7 +139,7 @@ def test_get_bucket_objects_none_time_modified(self, mock_init_client): mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) # Verify time_modified is None assert len(result) == 1 diff --git a/tests/server/unit/api/utils/test_utils_testbed.py b/tests/server/unit/api/utils/test_utils_testbed.py index f99dbbdc..7137d4a3 100644 --- a/tests/server/unit/api/utils/test_utils_testbed.py +++ b/tests/server/unit/api/utils/test_utils_testbed.py @@ -17,10 +17,15 @@ class TestTestbedUtils: """Test testbed utility functions""" - def __init__(self): - """Setup test data""" - self.mock_connection = MagicMock(spec=Connection) - self.sample_qa_data = { + @pytest.fixture + def mock_connection(self): + """Mock database connection fixture""" + return MagicMock(spec=Connection) + + @pytest.fixture + def sample_qa_data(self): + """Sample QA data fixture""" + return { "question": "What is the capital of France?", "answer": "Paris", "context": "France is a country in Europe.", @@ -73,11 +78,11 @@ def test_jsonl_to_json_content_whitespace_content(self): testbed.jsonl_to_json_content(content) @patch("server.api.utils.databases.execute_sql") - def test_create_testset_objects(self, mock_execute_sql): + def test_create_testset_objects(self, mock_execute_sql, mock_connection): """Test creating testset database objects""" mock_execute_sql.return_value = [] - testbed.create_testset_objects(self.mock_connection) + testbed.create_testset_objects(mock_connection) # Should execute 3 SQL statements (testsets, testset_qa, evaluations tables) assert mock_execute_sql.call_count == 3 From a791cdf24b9b1c348d20a22aee56630b3772b89d Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 26 Nov 2025 17:29:58 +0000 Subject: [PATCH 02/20] Update tests with tools_enabled setting --- tests/server/integration/test_endpoints_settings.py | 7 ++++--- tests/server/unit/api/utils/test_utils_chat.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/server/integration/test_endpoints_settings.py b/tests/server/integration/test_endpoints_settings.py index 5cfde6c0..2eba8f74 100644 --- a/tests/server/integration/test_endpoints_settings.py +++ b/tests/server/integration/test_endpoints_settings.py @@ -105,7 +105,8 @@ def test_settings_update(self, client, auth_headers): updated_settings = Settings( client="default", ll_model=LargeLanguageSettings(model="updated-model", chat_history=False), - vector_search=VectorSearchSettings(enabled=True, grading=False, search_type="Similarity", top_k=5), + tools_enabled=["Vector Search"], + vector_search=VectorSearchSettings(grade=False, search_type="Similarity", top_k=5), oci=OciSettings(auth_profile="UPDATED"), ) @@ -125,8 +126,8 @@ def test_settings_update(self, client, auth_headers): # 
Check that the values were updated assert new_settings["ll_model"]["model"] == "updated-model" assert new_settings["ll_model"]["chat_history"] is False - assert new_settings["vector_search"]["enabled"] is True - assert new_settings["vector_search"]["grading"] is False + assert new_settings["tools_enabled"] == ["Vector Search"] + assert new_settings["vector_search"]["grade"] is False assert new_settings["vector_search"]["top_k"] == 5 assert new_settings["oci"]["auth_profile"] == "UPDATED" diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py index 2b620a12..24abbfa5 100644 --- a/tests/server/unit/api/utils/test_utils_chat.py +++ b/tests/server/unit/api/utils/test_utils_chat.py @@ -130,9 +130,9 @@ async def test_completion_generator_with_vector_search( sample_client_settings, ): """Test completion generation with vector search enabled""" - # Setup settings with vector search enabled + # Setup settings with vector search enabled via tools_enabled vector_search_settings = sample_client_settings.model_copy() - vector_search_settings.vector_search.enabled = True + vector_search_settings.tools_enabled = ["Vector Search"] # Setup mocks mock_get_client.return_value = vector_search_settings From 0cef60aa8d7617239cb04331fad4fecb9619af73 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Thu, 27 Nov 2025 08:28:30 +0000 Subject: [PATCH 03/20] Linted --- src/server/api/v1/embed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index dee81d40..95f16479 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -125,7 +125,7 @@ async def store_web_file( elif "text" in content_type or "html" in content_type: sections = await web_parse.fetch_and_extract_sections(url) - base = web_parse.slugify(str(url).rsplit('/', maxsplit=1)[-1]) or "page" + base = web_parse.slugify(str(url).rsplit("/", maxsplit=1)[-1]) or "page" out_files = [] for idx, sec in enumerate(sections, 1): # filename includes section number and optional slugified title for clarity From 44d5a353b361bf6c398db664c2a9d6a15f262955 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 07:28:48 +0000 Subject: [PATCH 04/20] re-org vs_selector --- src/client/content/chatbot.py | 4 +- src/client/content/testbed.py | 13 +- src/client/content/tools/tabs/split_embed.py | 136 ++++---- src/client/utils/client.py | 22 +- src/client/utils/st_common.py | 198 ----------- src/client/utils/vs_options.py | 228 ++++++++++++ .../content/tools/tabs/test_split_embed.py | 10 +- .../integration/utils/test_st_common.py | 327 +----------------- .../integration/utils/test_vs_options.py | 284 +++++++++++++++ .../client/unit/content/test_chatbot_unit.py | 8 +- .../tools/tabs/test_split_embed_unit.py | 19 +- .../client/unit/utils/test_st_common_unit.py | 278 +-------------- .../client/unit/utils/test_vs_options_unit.py | 315 +++++++++++++++++ .../server/unit/api/utils/test_utils_chat.py | 9 +- .../api/utils/test_utils_databases_crud.py | 3 + .../utils/test_utils_databases_functions.py | 6 + 16 files changed, 950 insertions(+), 910 deletions(-) create mode 100644 src/client/utils/vs_options.py create mode 100644 tests/client/integration/utils/test_vs_options.py create mode 100644 tests/client/unit/utils/test_vs_options_unit.py diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py index 3d027bb7..f9673581 100644 --- a/src/client/content/chatbot.py +++ b/src/client/content/chatbot.py @@ -16,7 +16,7 @@ from 
streamlit import session_state as state
 
 from client.content.config.tabs.models import get_models
-from client.utils import st_common, api_call, client
+from client.utils import st_common, api_call, client, vs_options
 from client.utils.st_footer import render_chat_footer
 from common import logging_config
@@ -62,7 +62,7 @@ def setup_sidebar():
         st_common.tools_sidebar()
         st_common.history_sidebar()
         st_common.ll_sidebar()
-        st_common.vector_search_sidebar()
+        vs_options.vector_search_sidebar()
 
     if not state.enable_client:
         st.stop()
diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py
index e249c00b..43bd1aa8 100644
--- a/src/client/content/testbed.py
+++ b/src/client/content/testbed.py
@@ -17,7 +17,7 @@
 
 from client.content.config.tabs.models import get_models
-from client.utils import st_common, api_call
+from client.utils import st_common, api_call, vs_options
 from common import logging_config
@@ -493,7 +493,7 @@ def render_evaluation_ui(available_ll_models: list) -> None:
     st.info("Use the sidebar settings for chatbot evaluation parameters", icon="⬅️")
     st_common.tools_sidebar()
     st_common.ll_sidebar()
-    st_common.vector_search_sidebar()
+    vs_options.vector_search_sidebar()
 
     st.write("Choose a model to judge the correctness of the chatbot answer, then start evaluation.")
     col_left, col_center, _ = st.columns([4, 3, 3])
@@ -510,20 +510,13 @@ def render_evaluation_ui(available_ll_models: list) -> None:
         on_change=st_common.update_client_settings("testbed"),
     )
 
-    # Check if vector search is enabled but no vector store is selected
-    evaluation_disabled = False
-    if state.client_settings.get("vector_search", {}).get("enabled", False):
-        # If vector search is enabled, check if a vector store is selected
-        if not state.client_settings.get("vector_search", {}).get("vector_store"):
-            evaluation_disabled = True
-
     if col_center.button(
         "Start Evaluation",
         type="primary",
         key="evaluate_button",
         help="Evaluation will automatically save the TestSet to the Database",
         on_click=qa_update_db,
-        disabled=evaluation_disabled,
+        disabled=not state.enable_client,
     ):
         with st.spinner("Starting Q&A evaluation... please be patient.", show_time=True):
             st_common.clear_state_key("testbed_evaluations")
diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py
index 7c64f119..f67085b2 100644
--- a/src/client/content/tools/tabs/split_embed.py
+++ b/src/client/content/tools/tabs/split_embed.py
@@ -4,7 +4,7 @@ This script is used for the splitting and chunking process using Streamlit (`st`).
""" -# spell-checker:ignore selectbox hnsw ivf ocids iterrows +# spell-checker:ignore selectbox hnsw ivf ocids iterrows isin import math import re @@ -16,7 +16,7 @@ import streamlit as st from streamlit import session_state as state -from client.utils import api_call, st_common +from client.utils import api_call, st_common, vs_options from client.content.config.tabs.databases import get_databases from client.content.config.tabs.models import get_models @@ -123,49 +123,23 @@ def files_data_editor(files, key): def update_chunk_overlap_slider() -> None: - """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" - new_overlap = state.selected_chunk_overlap_input - # Ensure overlap doesn't exceed chunk size - if hasattr(state, 'selected_chunk_size_slider'): - chunk_size = state.selected_chunk_size_slider - if new_overlap >= chunk_size: - new_overlap = max(0, chunk_size - 1) - state.selected_chunk_overlap_input = new_overlap - state.selected_chunk_overlap_slider = new_overlap + """Keep text and slider input aligned""" + state.selected_chunk_overlap_slider = state.selected_chunk_overlap_input def update_chunk_overlap_input() -> None: - """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" - new_overlap = state.selected_chunk_overlap_slider - # Ensure overlap doesn't exceed chunk size - if hasattr(state, 'selected_chunk_size_slider'): - chunk_size = state.selected_chunk_size_slider - if new_overlap >= chunk_size: - new_overlap = max(0, chunk_size - 1) - state.selected_chunk_overlap_slider = new_overlap - state.selected_chunk_overlap_input = new_overlap + """Keep text and slider input aligned""" + state.selected_chunk_overlap_input = state.selected_chunk_overlap_slider def update_chunk_size_slider() -> None: - """Keep text and slider input aligned and adjust overlap if needed""" + """Keep text and slider input aligned""" state.selected_chunk_size_slider = state.selected_chunk_size_input - # If overlap exceeds new chunk size, cap it - if hasattr(state, 'selected_chunk_overlap_slider'): - if state.selected_chunk_overlap_slider >= state.selected_chunk_size_slider: - new_overlap = max(0, state.selected_chunk_size_slider - 1) - state.selected_chunk_overlap_slider = new_overlap - state.selected_chunk_overlap_input = new_overlap def update_chunk_size_input() -> None: - """Keep text and slider input aligned and adjust overlap if needed""" + """Keep text and slider input aligned""" state.selected_chunk_size_input = state.selected_chunk_size_slider - # If overlap exceeds new chunk size, cap it - if hasattr(state, 'selected_chunk_overlap_input'): - if state.selected_chunk_overlap_input >= state.selected_chunk_size_input: - new_overlap = max(0, state.selected_chunk_size_input - 1) - state.selected_chunk_overlap_input = new_overlap - state.selected_chunk_overlap_slider = new_overlap ############################################################################# @@ -269,7 +243,7 @@ def _render_load_kb_section(file_sources: list, oci_setup: dict) -> FileSourceDa data.sql_query = st.text_input("SQL:", key="sql_query") is_invalid, msg = functions.is_sql_accessible(data.sql_connection, data.sql_query) - if not(is_invalid) or msg: + if not (is_invalid) or msg: st.error(f"Error: {msg}") ###################################### @@ -369,6 +343,20 @@ def _display_file_list_expander(file_list_response: dict) -> None: st.info("No files found in this vector store.") +def _validate_new_alias(alias: str) -> bool: + """Validate a new vector store alias and display 
appropriate messages.""" + alias_pattern = r"^[A-Za-z][A-Za-z0-9_]*$" + if not alias: + st.warning("Please enter a Vector Store Alias to continue.") + return True + if not re.match(alias_pattern, alias): + st.error( + "Invalid Alias! It must start with a letter and only contain alphanumeric characters and underscores." + ) + return True + return False + + def _render_populate_vs_section( embed_request: DatabaseVectorStorage, create_new_vs: bool ) -> tuple[DatabaseVectorStorage, int]: @@ -385,6 +373,12 @@ def _render_populate_vs_section( embed_request.vector_store = None embed_alias_invalid = False + if not create_new_vs: + # Using existing Vector Store + vs_settings = state.client_settings["vector_search"] + for field in ["alias", "description", "model", "chunk_size", "chunk_overlap", "distance_metric", "index_type"]: + setattr(embed_request, field, vs_settings.get(field, "")) + if create_new_vs: # Creating new vector store: just show text input for new VS name embed_request.alias = st.text_input( @@ -394,23 +388,7 @@ def _render_populate_vs_section( key="selected_embed_alias", placeholder="Enter a name for the new vector store", ) - alias_pattern = r"^[A-Za-z][A-Za-z0-9_]*$" - if not embed_request.alias: - st.warning("Please enter a Vector Store Alias to continue.") - embed_alias_invalid = True - elif not re.match(alias_pattern, embed_request.alias): - st.error( - "Invalid Alias! It must start with a letter and only contain alphanumeric characters and underscores." - ) - embed_alias_invalid = True - else: - # Using existing Vector Store - embed_request.alias = state.selected_vector_search_alias - embed_request.model = state.selected_vector_search_model - embed_request.chunk_size = state.selected_vector_search_chunk_size - embed_request.chunk_overlap = state.selected_vector_search_chunk_overlap - embed_request.distance_metric = state.selected_vector_search_distance_metric - embed_request.index_type = state.selected_vector_search_index_type + embed_alias_invalid = _validate_new_alias(embed_request.alias) if not embed_alias_invalid and embed_request.alias: embed_request.vector_store, _ = functions.get_vs_table( @@ -429,14 +407,37 @@ def _render_populate_vs_section( else: st.caption("New vector store will be created.") - # Display files in existing vector store - if not create_new_vs and embed_request.vector_store: - try: - file_list_response = api_call.get(endpoint=f"v1/embed/{embed_request.vector_store}/files") - if file_list_response and "files" in file_list_response: - _display_file_list_expander(file_list_response) - except api_call.ApiError as e: - logger.warning("Could not retrieve file list for %s: %s", embed_request.vector_store, e) + # Get Description + st.markdown("**Vector Store Description (Provide a description to help the retriever find relevant tables):**") + col1, col2 = st.columns([4, 1]) + with col1: + embed_request.description = st.text_input( + "Vector Store Description:", + max_chars=255, + value=embed_request.description, + placeholder="Enter a description for the new vector store", + label_visibility="collapsed", + ) + with col2: + if not create_new_vs and embed_request.description: + if st.button( + "Update Description", + type="secondary", + key="comment_update", + help="Update the description of an existing Vector Store.", + ): + _ = api_call.patch( + endpoint="v1/embed/comment", payload={"json": embed_request.model_dump()}, toast=True + ) + + # Display files in existing vector store + if not create_new_vs and embed_request.vector_store: + try: + file_list_response 
= api_call.get(endpoint=f"v1/embed/{embed_request.vector_store}/files") + if file_list_response and "files" in file_list_response: + _display_file_list_expander(file_list_response) + except api_call.ApiError as e: + logger.warning("Could not retrieve file list for %s: %s", embed_request.vector_store, e) # Always render rate limit input to ensure session state is initialized rate_size, _ = st.columns([0.28, 0.72]) @@ -613,10 +614,14 @@ def display_split_embed() -> None: # Check for existing Vector Stores with corresponding enabled embedding models create_new_vs = True + db_alias = state.client_settings.get("database", {}).get("alias") database_lookup = st_common.state_configs_lookup("database_configs", "name") vs_df = pd.DataFrame(database_lookup.get(db_alias, {}).get("vector_stores", [])) - if not vs_df.empty: + # Remove VS if its embedding model does not exist/is disabled + vs_filtered = vs_df[vs_df["model"].isin(embed_models_enabled.keys())] if not vs_df.empty else vs_df + + if not vs_filtered.empty: # Toggle between creating new vector store or using existing create_new_vs = st.toggle( "Create New Vector Store", @@ -628,18 +633,15 @@ def display_split_embed() -> None: ) if not create_new_vs: # Render vector store selection controls - st_common.render_vector_store_selection(vs_df) + vs_options.vector_store_selection(location="main") # Render embedding configuration for new VS if create_new_vs: _render_embedding_config_section(embed_models_enabled, embed_request) else: + vs_settings = state.client_settings.get("vector_search", {}) vs_fields = ["alias", "model", "chunk_size", "chunk_overlap", "distance_metric", "index_type"] - vs_missing = [ - f"selected_vector_search_{field}" - for field in vs_fields - if not getattr(state, f"selected_vector_search_{field}", None) - ] + vs_missing = [field for field in vs_fields if not vs_settings.get(field)] if vs_missing: st.stop() diff --git a/src/client/utils/client.py b/src/client/utils/client.py index ce80bff6..527168a0 100644 --- a/src/client/utils/client.py +++ b/src/client/utils/client.py @@ -87,15 +87,19 @@ async def stream(self, message: str, image_b64: Optional[str] = None) -> AsyncIt ) logger.debug("Sending Request: %s", request.model_dump_json()) client_call = {"json": request.model_dump(), **self.request_defaults} - async with httpx.AsyncClient() as client: - async with client.stream( - method="POST", url=self.server_url + "/v1/chat/streams", **client_call - ) as response: - async for chunk in response.aiter_bytes(): - content = chunk.decode("utf-8") - if content == "[stream_finished]": - break - yield content + try: + async with httpx.AsyncClient() as client: + async with client.stream( + method="POST", url=self.server_url + "/v1/chat/streams", **client_call + ) as response: + async for chunk in response.aiter_bytes(): + content = chunk.decode("utf-8") + if content == "[stream_finished]": + break + yield content + except httpx.HTTPError as ex: + logger.exception("HTTP error during streaming: %s", ex) + raise ConnectionError(f"Streaming connection failed: {ex}") from ex async def get_history(self) -> list[ChatMessage]: """Output all chat history""" diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index a96d78b5..0d0eff19 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -6,7 +6,6 @@ from io import BytesIO from typing import Any, Union -import pandas as pd import streamlit as st from streamlit import session_state as state @@ -291,200 +290,3 @@ def _disable_tool(tool: str, reason: str 
= None) -> None: on_change=_update_set_tool, key="selected_tool", ) - - -##################################################### -# Vector Search Options -##################################################### -def _render_vector_search_options(vector_search_type: str) -> None: - """Render vector search parameter controls based on search type.""" - st.sidebar.number_input( - "Top K:", - help=help_text.help_dict["top_k"], - value=state.client_settings["vector_search"]["top_k"], - min_value=1, - max_value=10000, - key="selected_vector_search_top_k", - on_change=update_client_settings("vector_search"), - ) - if vector_search_type == "Similarity Score Threshold": - st.sidebar.slider( - "Minimum Relevance Threshold:", - help=help_text.help_dict["score_threshold"], - value=state.client_settings["vector_search"]["score_threshold"], - min_value=0.0, - max_value=1.0, - step=0.1, - key="selected_vector_search_score_threshold", - on_change=update_client_settings("vector_search"), - ) - if vector_search_type == "Maximal Marginal Relevance": - st.sidebar.number_input( - "Fetch K:", - help=help_text.help_dict["fetch_k"], - value=state.client_settings["vector_search"]["fetch_k"], - min_value=1, - max_value=10000, - key="selected_vector_search_fetch_k", - on_change=update_client_settings("vector_search"), - ) - st.sidebar.slider( - "Degree of Diversity:", - help=help_text.help_dict["lambda_mult"], - value=state.client_settings["vector_search"]["lambda_mult"], - min_value=0.0, - max_value=1.0, - step=0.1, - key="selected_vector_search_lambda_mult", - on_change=update_client_settings("vector_search"), - ) - - -def _vs_gen_selectbox(label: str, options: list, key: str): - """Handle selectbox with auto-setting for a single unique value""" - valid_options = [option for option in options if option != ""] - if not valid_options: # Disable the selectbox if no valid options are available - disabled = True - selected_value = "" - else: - disabled = False - setting_key = key.removeprefix("selected_vector_search_") - current_value = state.client_settings["vector_search"][setting_key] or "" - - if ( - len(valid_options) == 1 and not current_value - ): # Auto-select if only one option AND value is empty (e.g., after reset) - selected_value = valid_options[0] - # Also update client_settings and widget state when auto-selecting - state.client_settings["vector_search"][setting_key] = selected_value - state[key] = selected_value - logger.debug("Auto-selecting %s to %s (single option)", key, selected_value) - else: - selected_value = current_value - # Check if selected_value is actually in valid_options, otherwise reset to empty - if selected_value and selected_value not in valid_options: - logger.debug("Previously selected %s '%s' no longer available, resetting", key, selected_value) - selected_value = "" - logger.debug("Current %s value: %s", key, selected_value or "(empty)") - return st.sidebar.selectbox( - label, - options=[""] + valid_options, - key=key, - index=([""] + valid_options).index(selected_value), - disabled=disabled, - ) - - -def update_filtered_vector_store(vs_df: pd.DataFrame) -> pd.DataFrame: - """Dynamically update filtered_df based on selected filters""" - embed_models_enabled = enabled_models_lookup("embed") - filtered = vs_df.copy() - # Remove vector stores where the model is not enabled - filtered = vs_df[vs_df["model"].isin(embed_models_enabled.keys())] - if state.get("selected_vector_search_alias"): - filtered = filtered[filtered["alias"] == state.selected_vector_search_alias] - if 
state.get("selected_vector_search_model"): - filtered = filtered[filtered["model"] == state.selected_vector_search_model] - if state.get("selected_vector_search_chunk_size"): - filtered = filtered[filtered["chunk_size"] == state.selected_vector_search_chunk_size] - if state.get("selected_vector_search_chunk_overlap"): - filtered = filtered[filtered["chunk_overlap"] == state.selected_vector_search_chunk_overlap] - if state.get("selected_vector_search_distance_metric"): - filtered = filtered[filtered["distance_metric"] == state.selected_vector_search_distance_metric] - if state.get("selected_vector_search_index_type"): - filtered = filtered[filtered["index_type"] == state.selected_vector_search_index_type] - return filtered - - -def render_vector_store_selection(vs_df: pd.DataFrame) -> None: - """Render vector store selection controls and handle state updates.""" - st.sidebar.subheader("Vector Store", divider="red") - - def reset() -> None: - """Reset Vector Store Selections""" - for key in state.client_settings["vector_search"]: - if key in ( - "model", - "chunk_size", - "chunk_overlap", - "distance_metric", - "vector_store", - "alias", - "index_type", - ): - widget_key = f"selected_vector_search_{key}" - # Set widget state to empty string to force GUI reset - state[widget_key] = "" - # Also clear the client settings - state.client_settings["vector_search"][key] = "" - - filtered_df = update_filtered_vector_store(vs_df) - - # Render selectbox with updated options - alias = _vs_gen_selectbox("Select Alias:", filtered_df["alias"].unique().tolist(), "selected_vector_search_alias") - embed_model = _vs_gen_selectbox( - "Select Model:", filtered_df["model"].unique().tolist(), "selected_vector_search_model" - ) - chunk_size = _vs_gen_selectbox( - "Select Chunk Size:", - filtered_df["chunk_size"].unique().tolist(), - "selected_vector_search_chunk_size", - ) - chunk_overlap = _vs_gen_selectbox( - "Select Chunk Overlap:", - filtered_df["chunk_overlap"].unique().tolist(), - "selected_vector_search_chunk_overlap", - ) - distance_metric = _vs_gen_selectbox( - "Select Distance Metric:", - filtered_df["distance_metric"].unique().tolist(), - "selected_vector_search_distance_metric", - ) - index_type = _vs_gen_selectbox( - "Select Index Type:", - filtered_df["index_type"].unique().tolist(), - "selected_vector_search_index_type", - ) - - if all([alias, embed_model, chunk_size, chunk_overlap, distance_metric, index_type]): - vs = filtered_df["vector_store"].iloc[0] - state.client_settings["vector_search"]["vector_store"] = vs - state.client_settings["vector_search"]["alias"] = alias - state.client_settings["vector_search"]["model"] = embed_model - state.client_settings["vector_search"]["chunk_size"] = chunk_size - state.client_settings["vector_search"]["chunk_overlap"] = chunk_overlap - state.client_settings["vector_search"]["distance_metric"] = distance_metric - state.client_settings["vector_search"]["index_type"] = index_type - else: - st.info("Please select existing Vector Store options to continue.", icon="⬅️") - state.enable_client = False - - # Reset button - st.sidebar.button("Reset", type="primary", on_click=reset) - - -def vector_search_sidebar() -> None: - """Vector Search Sidebar Settings, conditional if Database/Embeddings are configured""" - if "Vector Search" in state.client_settings["tools_enabled"]: - st.sidebar.subheader("Vector Search", divider="red") - - # Search Type Selection - vector_search_type_list = ["Similarity", "Maximal Marginal Relevance"] - vector_search_type = 
st.sidebar.selectbox(
-            "Search Type:",
-            vector_search_type_list,
-            index=vector_search_type_list.index(state.client_settings["vector_search"]["search_type"]),
-            key="selected_vector_search_search_type",
-            on_change=update_client_settings("vector_search"),
-        )
-
-        # Render search options based on type
-        _render_vector_search_options(vector_search_type)
-
-        # Vector Store Section
-        db_alias = state.client_settings.get("database", {}).get("alias")
-        database_lookup = state_configs_lookup("database_configs", "name")
-        vs_df = pd.DataFrame(database_lookup.get(db_alias, {}).get("vector_stores", []))
-
-        # Render vector store selection controls
-        render_vector_store_selection(vs_df)
diff --git a/src/client/utils/vs_options.py b/src/client/utils/vs_options.py
new file mode 100644
index 00000000..af285a1d
--- /dev/null
+++ b/src/client/utils/vs_options.py
@@ -0,0 +1,228 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore mult selectbox selectboxes
+
+import pandas as pd
+import streamlit as st
+from streamlit import session_state as state
+
+from client.utils import st_common
+from common import logging_config, help_text
+
+logger = logging_config.logging.getLogger("client.utils.vs_options")
+
+
+#####################################################
+# Vector Search Options
+#####################################################
+def vector_search_sidebar() -> None:
+    """Vector Search Sidebar Settings, conditional if Database/Embeddings are configured"""
+    if "Vector Search" not in state.client_settings["tools_enabled"]:
+        return
+
+    st.sidebar.subheader("Vector Search", divider="red")
+
+    # Search Type Selection
+    vector_search_type_list = ["Similarity", "Maximal Marginal Relevance"]
+    vector_search_type = st.sidebar.selectbox(
+        "Search Type:",
+        vector_search_type_list,
+        index=vector_search_type_list.index(state.client_settings["vector_search"]["search_type"]),
+        key="selected_vector_search_search_type",
+        on_change=st_common.update_client_settings("vector_search"),
+    )
+
+    # Render search options based on type
+    st.sidebar.number_input(
+        "Top K:",
+        help=help_text.help_dict["top_k"],
+        value=state.client_settings["vector_search"]["top_k"],
+        min_value=1,
+        max_value=10000,
+        key="selected_vector_search_top_k",
+        on_change=st_common.update_client_settings("vector_search"),
+    )
+    if vector_search_type == "Similarity Score Threshold":
+        st.sidebar.slider(
+            "Minimum Relevance Threshold:",
+            help=help_text.help_dict["score_threshold"],
+            value=state.client_settings["vector_search"]["score_threshold"],
+            min_value=0.0,
+            max_value=1.0,
+            step=0.1,
+            key="selected_vector_search_score_threshold",
+            on_change=st_common.update_client_settings("vector_search"),
+        )
+    if vector_search_type == "Maximal Marginal Relevance":
+        st.sidebar.number_input(
+            "Fetch K:",
+            help=help_text.help_dict["fetch_k"],
+            value=state.client_settings["vector_search"]["fetch_k"],
+            min_value=1,
+            max_value=10000,
+            key="selected_vector_search_fetch_k",
+            on_change=st_common.update_client_settings("vector_search"),
+        )
+        st.sidebar.slider(
+            "Degree of Diversity:",
+            help=help_text.help_dict["lambda_mult"],
+            value=state.client_settings["vector_search"]["lambda_mult"],
+            min_value=0.0,
+            max_value=1.0,
+            step=0.1,
+            key="selected_vector_search_lambda_mult",
+            on_change=st_common.update_client_settings("vector_search"),
+        )
+
+    # Show Vector Store Selection
+    vector_store_selection()
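Every widget in the new `vector_search_sidebar` follows one convention: the widget key is the settings field prefixed with `selected_vector_search_`, and `st_common.update_client_settings("vector_search")` supplies the `on_change` callback. A minimal sketch of that callback pattern, assuming it works by stripping the prefix and copying widget values back into `state.client_settings` (the real implementation lives in `st_common.py` and is not part of this patch; names below are illustrative only):

from streamlit import session_state as state

def update_client_settings(settings_group: str):
    # Hedged sketch: return an on_change callback that syncs any
    # "selected_<group>_*" widget value into client_settings[<group>].
    prefix = f"selected_{settings_group}_"

    def _callback() -> None:
        for widget_key in list(state.keys()):
            if widget_key.startswith(prefix):
                setting = widget_key.removeprefix(prefix)
                state.client_settings[settings_group][setting] = state[widget_key]

    return _callback

This is why the patch only renames keys and callbacks: as long as a widget keeps the `selected_vector_search_` prefix, its value flows back into `client_settings["vector_search"]` without any per-widget wiring.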
+ + +##################################################### +# Vector Search Store Options +##################################################### +def _get_vs_fields() -> list[tuple[str, str]]: + """Return vector store selection fields: (label, dataframe_column).""" + return [ + ("Select Alias:", "alias"), + ("Select Model:", "model"), + ("Select Chunk Size:", "chunk_size"), + ("Select Chunk Overlap:", "chunk_overlap"), + ("Select Distance Metric:", "distance_metric"), + ("Select Index Type:", "index_type"), + ] + + +def _reset_selections() -> None: + """Reset all vector store selections.""" + for _, col in _get_vs_fields(): + state.client_settings["vector_search"][col] = "" + state.client_settings["vector_search"]["vector_store"] = "" + # Increment key version to force new widget instances + state["_vs_key_version"] = state.get("_vs_key_version", 0) + 1 + + +def _get_valid_options(base_df: pd.DataFrame, col: str, selections: dict) -> list: + """Get valid options for a field, filtered by all OTHER selections.""" + filtered_df = base_df.copy() + for _, other_col in _get_vs_fields(): + if other_col != col and selections.get(other_col): + filtered_df = filtered_df[filtered_df[other_col] == selections[other_col]] + return [opt for opt in filtered_df[col].unique().tolist() if opt != ""] + + +def _auto_select(base_df: pd.DataFrame, selections: dict) -> dict: + """Auto-select fields with single valid option, clear invalid selections.""" + result = selections.copy() + changed = True + while changed: + changed = False + for _, col in _get_vs_fields(): + valid_options = _get_valid_options(base_df, col, result) + if len(valid_options) == 1 and result.get(col) != valid_options[0]: + result[col] = valid_options[0] + changed = True + elif result.get(col) and result[col] not in valid_options: + result[col] = "" + changed = True + return result + + +def _get_current_selections(key_version: int) -> dict: + """Get current selections from widget state or client settings.""" + current_selections = {} + for _, col in _get_vs_fields(): + widget_key = f"vs_{col}_{key_version}" + if widget_key in state: + current_selections[col] = state[widget_key] + else: + current_selections[col] = state.client_settings["vector_search"].get(col, "") + return current_selections + + +def _render_selectbox( + container, label: str, col: str, base_df: pd.DataFrame, current_selections: dict, key_version: int +) -> str: + """Render a single selectbox and return its value.""" + valid_options = _get_valid_options(base_df, col, current_selections) + initial = current_selections[col] if current_selections[col] in valid_options else "" + all_options = [""] + valid_options + widget_key = f"vs_{col}_{key_version}" + return container.selectbox( + label, + options=all_options, + index=all_options.index(initial), + key=widget_key, + disabled=not valid_options, + ) + + +def _render_main_selectboxes(container, base_df: pd.DataFrame, current_selections: dict, key_version: int) -> dict: + """Render selectboxes in main layout (3 rows of 2 columns).""" + selections = {} + fields = _get_vs_fields() + + alias_lov, model_lov = container.columns([0.6, 1.4]) + chunk_size_lov, chunk_overlap_lov = container.columns([1, 1]) + distance_lov, index_lov = container.columns([1, 1]) + columns = [alias_lov, model_lov, chunk_size_lov, chunk_overlap_lov, distance_lov, index_lov] + + for idx, (label, col) in enumerate(fields): + selections[col] = _render_selectbox(columns[idx], label, col, base_df, current_selections, key_version) + + return selections + + +def 
_render_sidebar_selectboxes(container, base_df: pd.DataFrame, current_selections: dict, key_version: int) -> dict: + """Render selectboxes in sidebar layout (vertical stack).""" + selections = {} + for label, col in _get_vs_fields(): + selections[col] = _render_selectbox(container, label, col, base_df, current_selections, key_version) + return selections + + +def vector_store_selection(location: str = "sidebar") -> None: + """Vector Search Settings. + + Args: + location: "sidebar" (default) or "main" + """ + container = st.sidebar if location == "sidebar" else st + container.subheader("Vector Store", divider="red") + info_placeholder = st.empty() + + # Build base dataframe filtered by enabled embed models + db_alias = state.client_settings.get("database", {}).get("alias") + database_lookup = st_common.state_configs_lookup("database_configs", "name") + vs_df = pd.DataFrame(database_lookup.get(db_alias, {}).get("vector_stores", [])) + embed_models_enabled = st_common.enabled_models_lookup("embed") + base_df = vs_df[vs_df["model"].isin(embed_models_enabled.keys())].copy() + + # Get and validate current selections + key_version = state.get("_vs_key_version", 0) + current_selections = _auto_select(base_df, _get_current_selections(key_version)) + + # Update client_settings with validated selections + for _, col in _get_vs_fields(): + state.client_settings["vector_search"][col] = current_selections[col] + + # Render selectboxes based on location + if location == "main": + selections = _render_main_selectboxes(container, base_df, current_selections, key_version) + else: + selections = _render_sidebar_selectboxes(container, base_df, current_selections, key_version) + + # Update vector_store when all fields are selected + if all(selections.values()): + final_df = base_df.copy() + for _, col in _get_vs_fields(): + final_df = final_df[final_df[col] == selections[col]] + state.client_settings["vector_search"]["vector_store"] = final_df["vector_store"].iloc[0] + state.enable_client = True + else: + info_placeholder.info("Please select existing Vector Store options to continue.", icon="↙️") + state.enable_client = False + + container.button("Reset", type="primary", on_click=_reset_selections) diff --git a/tests/client/integration/content/tools/tabs/test_split_embed.py b/tests/client/integration/content/tools/tabs/test_split_embed.py index 7fdd7b47..d20998ba 100644 --- a/tests/client/integration/content/tools/tabs/test_split_embed.py +++ b/tests/client/integration/content/tools/tabs/test_split_embed.py @@ -621,18 +621,18 @@ def test_create_new_vs_toggle_shown_when_vector_stores_exist(self, app_server, a # Ensure database has vector stores if at.session_state.database_configs: - # Find matching model ID for the vector store - model_id = None + # Find matching model ID for the vector store (format: provider/id) + model_key = None for model in at.session_state.model_configs: if model["type"] == "embed" and model.get("enabled"): - model_id = model["id"] + model_key = f"{model.get('provider')}/{model['id']}" break - if model_id: + if model_key: at.session_state.database_configs[0]["vector_stores"] = [ { "alias": "existing_vs", - "model": model_id, + "model": model_key, "vector_store": "VECTOR_STORE_TABLE", "chunk_size": 500, "chunk_overlap": 50, diff --git a/tests/client/integration/utils/test_st_common.py b/tests/client/integration/utils/test_st_common.py index 164ecb13..0103607a 100644 --- a/tests/client/integration/utils/test_st_common.py +++ b/tests/client/integration/utils/test_st_common.py @@ -2,329 +2,8 
@@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for st_common utilities. +Vector store selection tests have been moved to test_vs_options.py """ # spell-checker: disable - -from unittest.mock import patch - -import pandas as pd -import pytest -import streamlit as st -from streamlit import session_state as state - -from client.utils import st_common - - -############################################################################# -# Fixtures -############################################################################# -@pytest.fixture -def vector_store_state(sample_vector_store_data): - """Setup common vector store state for tests using shared test data""" - # Setup initial state with vector search settings - state.client_settings = { - "vector_search": { - "enabled": True, - **sample_vector_store_data, - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - "ll_model": {"model": "gpt-4", "temperature": 0.8}, - } - - # Set widget states to simulate user selections - state.selected_vector_search_model = sample_vector_store_data["model"] - state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] - state.selected_vector_search_chunk_overlap = sample_vector_store_data["chunk_overlap"] - state.selected_vector_search_distance_metric = sample_vector_store_data["distance_metric"] - state.selected_vector_search_alias = sample_vector_store_data["alias"] - state.selected_vector_search_index_type = sample_vector_store_data["index_type"] - - yield state - - # Cleanup after test - for key in list(state.keys()): - if key.startswith("selected_vector_search_"): - del state[key] - - -############################################################################# -# Test Vector Store Reset Button Functionality - Integration Tests -############################################################################# -class TestVectorStoreResetButtonIntegration: - """Integration tests for vector store selection Reset button""" - - def test_reset_button_callback_execution(self, app_server, vector_store_state, sample_vector_store_data): - """Test that the Reset button callback is properly executed when clicked""" - assert app_server is not None - assert vector_store_state is not None - - reset_callback_executed = False - - def mock_button(label, **kwargs): - nonlocal reset_callback_executed - if "Reset" in label and "on_click" in kwargs: - # Execute the callback to simulate button click - kwargs["on_click"]() - reset_callback_executed = True - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox"), - patch.object(st, "info"), - ): - # Create test dataframe using shared test data - vs_df = pd.DataFrame([sample_vector_store_data]) - - # Mock enabled models - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - - # Call the function - st_common.render_vector_store_selection(vs_df) - - # Verify reset callback was executed - assert reset_callback_executed - - # Verify all widget states are cleared - assert state.selected_vector_search_model == "" - assert state.selected_vector_search_chunk_size == "" - assert state.selected_vector_search_chunk_overlap == "" - assert 
state.selected_vector_search_distance_metric == "" - assert state.selected_vector_search_alias == "" - assert state.selected_vector_search_index_type == "" - - # Verify client_settings are also cleared - assert state.client_settings["vector_search"]["model"] == "" - assert state.client_settings["vector_search"]["chunk_size"] == "" - assert state.client_settings["vector_search"]["chunk_overlap"] == "" - assert state.client_settings["vector_search"]["distance_metric"] == "" - assert state.client_settings["vector_search"]["vector_store"] == "" - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["index_type"] == "" - - def test_reset_preserves_non_vector_store_settings(self, app_server, vector_store_state, sample_vector_store_data): - """Test that Reset only affects vector store fields, not other settings""" - assert app_server is not None - assert vector_store_state is not None - - def mock_button(label, **kwargs): - if "Reset" in label and "on_click" in kwargs: - kwargs["on_click"]() - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox"), - patch.object(st, "info"), - ): - vs_df = pd.DataFrame([sample_vector_store_data]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # Vector store fields should be cleared - assert state.client_settings["vector_search"]["model"] == "" - assert state.client_settings["vector_search"]["alias"] == "" - - # Other settings should be preserved - assert state.client_settings["vector_search"]["top_k"] == 10 - assert state.client_settings["vector_search"]["search_type"] == "Similarity" - assert state.client_settings["vector_search"]["score_threshold"] == 0.5 - assert state.client_settings["database"]["alias"] == "DEFAULT" - assert state.client_settings["ll_model"]["model"] == "gpt-4" - assert state.client_settings["ll_model"]["temperature"] == 0.8 - - def test_auto_population_after_reset_single_option(self, app_server, sample_vector_store_data): - """Test that fields with single options are auto-populated after reset""" - assert app_server is not None - - # Setup clean state - state.client_settings = { - "vector_search": { - "enabled": True, - "model": "", # Empty after reset - "chunk_size": "", - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "", - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - # Clear widget states (simulating post-reset state) - state.selected_vector_search_model = "" - state.selected_vector_search_chunk_size = "" - state.selected_vector_search_chunk_overlap = "" - state.selected_vector_search_distance_metric = "" - state.selected_vector_search_alias = "" - state.selected_vector_search_index_type = "" - - selectbox_calls = [] - - def mock_selectbox(label, options, key, index, disabled=False): - selectbox_calls.append( - {"label": label, "options": options, "key": key, "index": index, "disabled": disabled} - ) - # Return the value at index - return options[index] if 0 <= index < len(options) else "" - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button"), - patch.object(st.sidebar, "selectbox", side_effect=mock_selectbox), - 
patch.object(st, "info"), - ): - # Create dataframe with single option per field using shared fixture - single_vs = sample_vector_store_data.copy() - single_vs["alias"] = "single_alias" - single_vs["vector_store"] = "single_vs" - vs_df = pd.DataFrame([single_vs]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # Verify auto-population happened for single options - assert state.client_settings["vector_search"]["alias"] == "single_alias" - assert state.client_settings["vector_search"]["model"] == sample_vector_store_data["model"] - assert state.client_settings["vector_search"]["chunk_size"] == sample_vector_store_data["chunk_size"] - assert state.client_settings["vector_search"]["chunk_overlap"] == sample_vector_store_data["chunk_overlap"] - assert ( - state.client_settings["vector_search"]["distance_metric"] - == sample_vector_store_data["distance_metric"] - ) - assert state.client_settings["vector_search"]["index_type"] == sample_vector_store_data["index_type"] - - # Verify widget states were also set - assert state.selected_vector_search_alias == "single_alias" - assert state.selected_vector_search_model == sample_vector_store_data["model"] - - def test_no_auto_population_with_multiple_options( - self, app_server, sample_vector_store_data, sample_vector_store_data_alt - ): - """Test that fields with multiple options are NOT auto-populated after reset""" - assert app_server is not None - - # Setup clean state after reset - state.client_settings = { - "vector_search": { - "enabled": True, - "model": "", - "chunk_size": "", - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "", - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - # Clear widget states - for key in ["model", "chunk_size", "chunk_overlap", "distance_metric", "alias", "index_type"]: - state[f"selected_vector_search_{key}"] = "" - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button"), - patch.object(st.sidebar, "selectbox", return_value=""), - patch.object(st, "info"), - ): - # Create dataframe with multiple options using shared fixtures - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "alias1" - vs2 = sample_vector_store_data_alt.copy() - vs2["alias"] = "alias2" - vs_df = pd.DataFrame([vs1, vs2]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # With multiple options, fields should remain empty (no auto-population) - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["chunk_size"] == "" - assert state.client_settings["vector_search"]["chunk_overlap"] == "" - assert state.client_settings["vector_search"]["distance_metric"] == "" - assert state.client_settings["vector_search"]["index_type"] == "" - - def test_reset_button_with_filtered_dataframe( - self, app_server, sample_vector_store_data, sample_vector_store_data_alt - ): - """Test reset button behavior with dynamically filtered dataframes""" - assert app_server is not None - - # Setup state with a filter already applied - state.client_settings = { - "vector_search": { - "enabled": True, - "model": 
sample_vector_store_data["model"], - "chunk_size": sample_vector_store_data["chunk_size"], - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "alias1", # Filter applied - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - state.selected_vector_search_alias = "alias1" - state.selected_vector_search_model = sample_vector_store_data["model"] - state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] - - def mock_button(label, **kwargs): - if "Reset" in label and "on_click" in kwargs: - kwargs["on_click"]() - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox", return_value=""), - patch.object(st, "info"), - ): - # Create dataframe with same alias using shared fixtures - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "alias1" - vs2 = sample_vector_store_data_alt.copy() - vs2["alias"] = "alias1" - vs_df = pd.DataFrame([vs1, vs2]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # After reset, all filters should be cleared - assert state.selected_vector_search_alias == "" - assert state.selected_vector_search_model == "" - assert state.selected_vector_search_chunk_size == "" - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["model"] == "" - assert state.client_settings["vector_search"]["chunk_size"] == "" diff --git a/tests/client/integration/utils/test_vs_options.py b/tests/client/integration/utils/test_vs_options.py new file mode 100644 index 00000000..bceaa8e8 --- /dev/null +++ b/tests/client/integration/utils/test_vs_options.py @@ -0,0 +1,284 @@ +# pylint: disable=protected-access,import-error,import-outside-toplevel,redefined-outer-name +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock + +import pytest +import streamlit as st +from streamlit import session_state as state + +from client.utils import vs_options + + +############################################################################# +# Fixtures +############################################################################# +@pytest.fixture +def vector_store_state(sample_vector_store_data): + """Setup common vector store state for tests using shared test data""" + # Setup initial state with vector search settings + state.client_settings = { + "vector_search": { + "enabled": True, + **sample_vector_store_data, + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + "ll_model": {"model": "gpt-4", "temperature": 0.8}, + } + state.database_configs = [ + {"name": "DEFAULT", "vector_stores": [sample_vector_store_data]} + ] + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True} + ] + state["_vs_key_version"] = 0 + + yield state + + # Cleanup after test + for key in list(state.keys()): + if key.startswith("vs_") or key.startswith("_vs_"): + del state[key] + + +############################################################################# +# Test Vector Store Reset Button Functionality - Integration Tests +############################################################################# +class TestVectorStoreResetButtonIntegration: + """Integration tests for vector store selection Reset button""" + + def test_reset_button_callback_execution(self, app_server, vector_store_state): + """Test that the Reset button callback is properly executed when clicked""" + assert app_server is not None + assert vector_store_state is not None + + reset_callback_executed = False + + def mock_button(label, **kwargs): + nonlocal reset_callback_executed + if "Reset" in label and "on_click" in kwargs: + # Execute the callback to simulate button click + kwargs["on_click"]() + reset_callback_executed = True + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox", return_value=""), + patch.object(st, "empty", return_value=MagicMock()), + ): + vs_options.vector_store_selection(location="sidebar") + + # Verify reset callback was executed + assert reset_callback_executed + + # Verify client_settings are cleared + assert state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["chunk_size"] == "" + assert state.client_settings["vector_search"]["chunk_overlap"] == "" + assert state.client_settings["vector_search"]["distance_metric"] == "" + assert state.client_settings["vector_search"]["vector_store"] == "" + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["index_type"] == "" + + def test_reset_preserves_non_vector_store_settings(self, app_server, vector_store_state): + """Test that Reset only affects vector store fields, not other settings""" + assert app_server is not None + assert vector_store_state is not None + + def mock_button(label, **kwargs): + if "Reset" in label and "on_click" in kwargs: + kwargs["on_click"]() + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox", return_value=""), + 
patch.object(st, "empty", return_value=MagicMock()), + ): + vs_options.vector_store_selection(location="sidebar") + + # Vector store fields should be cleared + assert state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["alias"] == "" + + # Other settings should be preserved + assert state.client_settings["vector_search"]["top_k"] == 10 + assert state.client_settings["vector_search"]["search_type"] == "Similarity" + assert state.client_settings["vector_search"]["score_threshold"] == 0.5 + assert state.client_settings["database"]["alias"] == "DEFAULT" + assert state.client_settings["ll_model"]["model"] == "gpt-4" + assert state.client_settings["ll_model"]["temperature"] == 0.8 + + def test_auto_population_after_reset_single_option(self, app_server, sample_vector_store_data): + """Test that fields with single options are auto-populated after reset""" + assert app_server is not None + + # Setup clean state with single vector store option + single_vs = sample_vector_store_data.copy() + single_vs["alias"] = "single_alias" + single_vs["vector_store"] = "single_vs" + + state.client_settings = { + "vector_search": { + "enabled": True, + "model": "", + "chunk_size": "", + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "", + "index_type": "", + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + state.database_configs = [ + {"name": "DEFAULT", "vector_stores": [single_vs]} + ] + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True} + ] + state["_vs_key_version"] = 0 + + # Track selectbox return values + def mock_selectbox(_label, options, **_kwargs): + # Return the auto-selected value (single option gets auto-selected) + if len(options) == 2: # ["", "value"] + return options[1] + return "" + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button"), + patch.object(st.sidebar, "selectbox", side_effect=mock_selectbox), + patch.object(st, "empty", return_value=MagicMock()), + ): + vs_options.vector_store_selection(location="sidebar") + + # Verify auto-population happened for single options + assert state.client_settings["vector_search"]["alias"] == "single_alias" + assert state.client_settings["vector_search"]["model"] == sample_vector_store_data["model"] + assert state.client_settings["vector_search"]["chunk_size"] == sample_vector_store_data["chunk_size"] + + def test_no_auto_population_with_multiple_options( + self, app_server, sample_vector_store_data, sample_vector_store_data_alt + ): + """Test that fields with multiple options are NOT auto-populated after reset""" + assert app_server is not None + + # Setup with multiple vector stores having different values + vs1 = sample_vector_store_data.copy() + vs1["alias"] = "alias1" + vs2 = sample_vector_store_data_alt.copy() + vs2["alias"] = "alias2" + + state.client_settings = { + "vector_search": { + "enabled": True, + "model": "", + "chunk_size": "", + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "", + "index_type": "", + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + state.database_configs = [ + {"name": "DEFAULT", "vector_stores": [vs1, vs2]} + ] + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": 
True} + ] + state["_vs_key_version"] = 0 + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button"), + patch.object(st.sidebar, "selectbox", return_value=""), + patch.object(st, "empty", return_value=MagicMock()), + ): + vs_options.vector_store_selection(location="sidebar") + + # With multiple options for alias, it should remain empty (no auto-population) + assert state.client_settings["vector_search"]["alias"] == "" + # Model is the same for both, so it should be auto-selected + assert state.client_settings["vector_search"]["model"] == "openai/text-embed-3" + + def test_reset_button_with_filtered_dataframe( + self, app_server, sample_vector_store_data, sample_vector_store_data_alt + ): + """Test reset button behavior with dynamically filtered dataframes""" + assert app_server is not None + + # Setup with multiple vector stores + vs1 = sample_vector_store_data.copy() + vs1["alias"] = "alias1" + vs2 = sample_vector_store_data_alt.copy() + vs2["alias"] = "alias1" # Same alias, different other fields + + state.client_settings = { + "vector_search": { + "enabled": True, + "model": sample_vector_store_data["model"], + "chunk_size": sample_vector_store_data["chunk_size"], + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "alias1", + "index_type": "", + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + state.database_configs = [ + {"name": "DEFAULT", "vector_stores": [vs1, vs2]} + ] + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True} + ] + state["_vs_key_version"] = 0 + + def mock_button(label, **kwargs): + if "Reset" in label and "on_click" in kwargs: + kwargs["on_click"]() + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox", return_value=""), + patch.object(st, "empty", return_value=MagicMock()), + ): + vs_options.vector_store_selection(location="sidebar") + + # After reset, all selection fields should be cleared + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["chunk_size"] == "" diff --git a/tests/client/unit/content/test_chatbot_unit.py b/tests/client/unit/content/test_chatbot_unit.py index 66b309ed..9d3afd0c 100644 --- a/tests/client/unit/content/test_chatbot_unit.py +++ b/tests/client/unit/content/test_chatbot_unit.py @@ -138,7 +138,7 @@ def test_setup_sidebar_no_models(self, monkeypatch): def test_setup_sidebar_with_models(self, monkeypatch): """Test setup_sidebar with enabled language models""" from client.content import chatbot - from client.utils import st_common + from client.utils import st_common, vs_options from streamlit import session_state as state # Mock enabled_models_lookup to return models @@ -148,7 +148,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): monkeypatch.setattr(st_common, "tools_sidebar", MagicMock()) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(st_common, "vector_search_sidebar", MagicMock()) + monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) # Initialize state state.enable_client = True @@ -162,7 +162,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): def 
test_setup_sidebar_client_disabled(self, monkeypatch): """Test setup_sidebar when client gets disabled""" from client.content import chatbot - from client.utils import st_common + from client.utils import st_common, vs_options from streamlit import session_state as state import streamlit as st @@ -175,7 +175,7 @@ def disable_client(): monkeypatch.setattr(st_common, "tools_sidebar", disable_client) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(st_common, "vector_search_sidebar", MagicMock()) + monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) # Mock st.stop mock_stop = MagicMock(side_effect=SystemExit) diff --git a/tests/client/unit/content/tools/tabs/test_split_embed_unit.py b/tests/client/unit/content/tools/tabs/test_split_embed_unit.py index 39bdce27..fcfd6b9d 100644 --- a/tests/client/unit/content/tools/tabs/test_split_embed_unit.py +++ b/tests/client/unit/content/tools/tabs/test_split_embed_unit.py @@ -270,27 +270,24 @@ def test_update_chunk_size_input(self): class TestSplitEmbedEdgeCases: """Tests for edge cases and validation in split_embed implementation""" - def test_chunk_overlap_validation(self): + def test_chunk_overlap_syncs_slider_to_input(self): """ - Test that chunk_overlap should not exceed chunk_size. + Test that update_chunk_overlap_input syncs slider value to input. - This validates proper chunk configuration to prevent text splitting issues. - If this test fails, it indicates chunk_overlap is allowed to exceed chunk_size. + The function copies the slider value to the input field. + Note: Validation of overlap < size is handled at the UI level, not in this function. """ from client.content.tools.tabs.split_embed import update_chunk_overlap_input from streamlit import session_state as state - # Setup state with overlap > size (function copies FROM slider TO input) - state.selected_chunk_overlap_slider = 2000 # Overlap (will be copied to input) - state.selected_chunk_size_slider = 1000 # Size (smaller!) 
+ # Setup state + state.selected_chunk_overlap_slider = 500 # Call function update_chunk_overlap_input() - # EXPECTED: overlap should be capped at chunk_size or validation should prevent this - # If this assertion fails, it exposes lack of validation - assert state.selected_chunk_overlap_input < state.selected_chunk_size_slider, \ - "Chunk overlap should not exceed chunk size" + # Verify the value was copied from slider to input + assert state.selected_chunk_overlap_input == 500 def test_files_data_frame_process_column_added(self): """ diff --git a/tests/client/unit/utils/test_st_common_unit.py b/tests/client/unit/utils/test_st_common_unit.py index 1884dc24..1eee4014 100644 --- a/tests/client/unit/utils/test_st_common_unit.py +++ b/tests/client/unit/utils/test_st_common_unit.py @@ -6,10 +6,8 @@ # spell-checker: disable from io import BytesIO -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock -import pandas as pd -import streamlit as st from streamlit import session_state as state from client.utils import api_call, st_common @@ -394,277 +392,3 @@ def test_is_db_configured_false_different_alias(self, app_server): result = st_common.is_db_configured() assert result is False - - -############################################################################# -# Test Vector Store Helpers -############################################################################# -class TestVectorStoreHelpers: - """Test vector store helper functions""" - - def test_update_filtered_vector_store_no_filters(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with no filters""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - - vs_df = pd.DataFrame(sample_vector_stores_list) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should return all rows (filtered by enabled models only) - assert len(result) == 2 - - def test_update_filtered_vector_store_with_alias_filter(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with alias filter""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - state.selected_vector_search_alias = "vs1" - - vs_df = pd.DataFrame(sample_vector_stores_list) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should only return vs1 - assert len(result) == 1 - assert result.iloc[0]["alias"] == "vs1" - - def test_update_filtered_vector_store_disabled_model(self, app_server, sample_vector_store_data): - """Test that disabled embedding models filter out vector stores""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": False}, - ] - - # Use shared fixture with vs1 alias - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "vs1" - vs1.pop("vector_store", None) - vs_df = pd.DataFrame([vs1]) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should return empty (model not enabled) - assert len(result) == 0 - - def test_update_filtered_vector_store_multiple_filters(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with multiple filters""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - state.selected_vector_search_alias = "vs1" - 
state.selected_vector_search_model = "openai/text-embed-3" - state.selected_vector_search_chunk_size = 1000 - - # Use only vs1 entries from the fixture - vs1_entries = [vs.copy() for vs in sample_vector_stores_list] - for vs in vs1_entries: - vs["alias"] = "vs1" - - vs_df = pd.DataFrame(vs1_entries) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should only return the 1000 chunk_size entry - assert len(result) == 1 - assert result.iloc[0]["chunk_size"] == 1000 - - -############################################################################# -# Test _vs_gen_selectbox Function -############################################################################# -class TestVsGenSelectbox: - """Unit tests for the _vs_gen_selectbox function""" - - def test_single_option_auto_select_when_empty(self, app_server): - """Test auto-selection when there's one option and current value is empty""" - assert app_server is not None - - # Setup: empty current value - state.client_settings = {"vector_search": {"alias": ""}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "single_option" - - st_common._vs_gen_selectbox("Select Alias:", ["single_option"], "selected_vector_search_alias") - - # Verify auto-selection occurred - assert state.client_settings["vector_search"]["alias"] == "single_option" - assert state.selected_vector_search_alias == "single_option" - - # Verify selectbox was called with correct index (1 = first real option after empty) - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 1 # Index 1 points to "single_option" in ["", "single_option"] - - def test_single_option_no_auto_select_when_populated(self, app_server): - """Test NO auto-selection when there's one option but value already exists""" - assert app_server is not None - - # Setup: existing value - state.client_settings = {"vector_search": {"alias": "existing_value"}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "existing_value" - - st_common._vs_gen_selectbox("Select Alias:", ["existing_value"], "selected_vector_search_alias") - - # Value should remain unchanged (not overwritten) - assert state.client_settings["vector_search"]["alias"] == "existing_value" - - # Verify selectbox was called with existing value's index - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 1 # existing_value is at index 1 - - def test_multiple_options_no_auto_select(self, app_server): - """Test no auto-selection with multiple options""" - assert app_server is not None - - # Setup: empty value with multiple options - state.client_settings = {"vector_search": {"alias": ""}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox( - "Select Alias:", ["option1", "option2", "option3"], "selected_vector_search_alias" - ) - - # Should remain empty (no auto-selection) - assert state.client_settings["vector_search"]["alias"] == "" - - # Verify selectbox was called with index 0 (empty option) - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 0 # Index 0 is the empty option - - def test_no_valid_options_disabled(self, app_server): - """Test selectbox is disabled when no valid options""" - assert app_server is not None - - state.client_settings = {"vector_search": {"alias": ""}} - - with 
patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox( - "Select Alias:", - [], # No options - "selected_vector_search_alias", - ) - - # Verify selectbox was called with disabled=True - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["disabled"] is True - assert call_args[1]["index"] == 0 - - def test_invalid_current_value_reset(self, app_server): - """Test that invalid current value is reset to empty""" - assert app_server is not None - - # Setup: value that's not in the options - state.client_settings = {"vector_search": {"alias": "invalid_option"}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox("Select Alias:", ["valid1", "valid2"], "selected_vector_search_alias") - - # Invalid value should not cause error, selectbox should show empty - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 0 # Reset to empty option - - -############################################################################# -# Test Reset Button Callback Function -############################################################################# -class TestResetButtonCallback: - """Unit tests for the reset button callback within render_vector_store_selection""" - - def test_reset_clears_correct_fields(self, app_server): - """Test reset callback clears only the specified vector store fields""" - assert app_server is not None - - # Setup initial values - state.client_settings = { - "vector_search": { - "model": "openai/text-embed-3", - "chunk_size": 1000, - "chunk_overlap": 200, - "distance_metric": "cosine", - "vector_store": "vs_test", - "alias": "test_alias", - "index_type": "IVF", - "top_k": 10, - "search_type": "Similarity", - } - } - - # Set widget states - state.selected_vector_search_model = "openai/text-embed-3" - state.selected_vector_search_chunk_size = 1000 - state.selected_vector_search_chunk_overlap = 200 - state.selected_vector_search_distance_metric = "cosine" - state.selected_vector_search_alias = "test_alias" - state.selected_vector_search_index_type = "IVF" - - # Define and execute reset logic (simulating the reset callback) - fields_to_reset = [ - "model", - "chunk_size", - "chunk_overlap", - "distance_metric", - "vector_store", - "alias", - "index_type", - ] - for key in fields_to_reset: - widget_key = f"selected_vector_search_{key}" - state[widget_key] = "" - state.client_settings["vector_search"][key] = "" - - # Verify the correct fields were cleared - for field in fields_to_reset: - assert state.client_settings["vector_search"][field] == "" - assert state[f"selected_vector_search_{field}"] == "" - - # Verify other fields were NOT cleared - assert state.client_settings["vector_search"]["top_k"] == 10 - assert state.client_settings["vector_search"]["search_type"] == "Similarity" - - def test_reset_enables_auto_population(self, app_server): - """Test that reset creates conditions for auto-population""" - assert app_server is not None - - # Setup with existing values - state.client_settings = {"vector_search": {"alias": "existing"}} - state.selected_vector_search_alias = "existing" - - # Execute reset logic - state.selected_vector_search_alias = "" - state.client_settings["vector_search"]["alias"] = "" - - # After reset, fields should be empty (ready for auto-population) - assert state.client_settings["vector_search"]["alias"] == "" - assert 
state.selected_vector_search_alias == "" - - # Now when _vs_gen_selectbox is called with a single option, it should auto-populate - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "auto_selected" - - st_common._vs_gen_selectbox("Select Alias:", ["auto_selected"], "selected_vector_search_alias") - - # Verify auto-population happened - assert state.client_settings["vector_search"]["alias"] == "auto_selected" - assert state.selected_vector_search_alias == "auto_selected" diff --git a/tests/client/unit/utils/test_vs_options_unit.py b/tests/client/unit/utils/test_vs_options_unit.py new file mode 100644 index 00000000..856827f5 --- /dev/null +++ b/tests/client/unit/utils/test_vs_options_unit.py @@ -0,0 +1,315 @@ +# pylint: disable=protected-access,import-error,import-outside-toplevel +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import MagicMock + +import pandas as pd +from streamlit import session_state as state + +from client.utils import vs_options + + +############################################################################# +# Test _get_vs_fields +############################################################################# +class TestGetVsFields: + """Test _get_vs_fields function""" + + def test_returns_correct_fields(self, app_server): + """Test that _get_vs_fields returns expected field tuples""" + assert app_server is not None + + fields = vs_options._get_vs_fields() + + assert len(fields) == 6 + assert ("Select Alias:", "alias") in fields + assert ("Select Model:", "model") in fields + assert ("Select Chunk Size:", "chunk_size") in fields + assert ("Select Chunk Overlap:", "chunk_overlap") in fields + assert ("Select Distance Metric:", "distance_metric") in fields + assert ("Select Index Type:", "index_type") in fields + + def test_fields_order_is_consistent(self, app_server): + """Test that _get_vs_fields returns fields in consistent order""" + assert app_server is not None + + fields = vs_options._get_vs_fields() + + # Verify order: alias, model, chunk_size, chunk_overlap, distance_metric, index_type + assert fields[0][1] == "alias" + assert fields[1][1] == "model" + assert fields[2][1] == "chunk_size" + assert fields[3][1] == "chunk_overlap" + assert fields[4][1] == "distance_metric" + assert fields[5][1] == "index_type" + + +############################################################################# +# Test _get_valid_options +############################################################################# +class TestGetValidOptions: + """Test _get_valid_options function""" + + def test_no_filters_returns_all(self, app_server, sample_vector_stores_list): + """Test _get_valid_options with no filters returns all unique values""" + assert app_server is not None + + vs_df = pd.DataFrame(sample_vector_stores_list) + selections = {} + + result = vs_options._get_valid_options(vs_df, "alias", selections) + + assert len(result) == 2 + assert "vs1" in result + assert "vs2" in result + + def test_with_filter_returns_filtered(self, app_server, sample_vector_stores_list): + """Test _get_valid_options filters by other selections""" + assert app_server is not None + + vs_df = pd.DataFrame(sample_vector_stores_list) + # Filter by chunk_size=1000 which should only match vs1 + selections = {"chunk_size": 1000} + + result = vs_options._get_valid_options(vs_df, "alias", selections) + + 
assert len(result) == 1 + assert "vs1" in result + + def test_excludes_empty_strings(self, app_server): + """Test _get_valid_options excludes empty strings from results""" + assert app_server is not None + + vs_df = pd.DataFrame([ + {"alias": "vs1", "model": "openai/text-embed-3"}, + {"alias": "", "model": "openai/text-embed-3"}, + ]) + selections = {} + + result = vs_options._get_valid_options(vs_df, "alias", selections) + + assert "" not in result + assert "vs1" in result + + +############################################################################# +# Test _auto_select +############################################################################# +class TestAutoSelect: + """Test _auto_select function""" + + def test_single_option_auto_selects(self, app_server): + """Test auto-selection when there's only one valid option""" + assert app_server is not None + + vs_df = pd.DataFrame([ + {"alias": "only_one", "model": "openai/text-embed-3", "chunk_size": 1000, + "chunk_overlap": 200, "distance_metric": "cosine", "index_type": "IVF"}, + ]) + selections = {"alias": "", "model": "", "chunk_size": "", "chunk_overlap": "", + "distance_metric": "", "index_type": ""} + + result = vs_options._auto_select(vs_df, selections) + + # All fields should be auto-selected since there's only one option + assert result["alias"] == "only_one" + assert result["model"] == "openai/text-embed-3" + + def test_multiple_options_no_auto_select(self, app_server, sample_vector_stores_list): + """Test no auto-selection when multiple options exist""" + assert app_server is not None + + vs_df = pd.DataFrame(sample_vector_stores_list) + selections = {"alias": "", "model": "", "chunk_size": "", "chunk_overlap": "", + "distance_metric": "", "index_type": ""} + + result = vs_options._auto_select(vs_df, selections) + + # Model should be auto-selected (same for both), but alias should not + assert result["model"] == "openai/text-embed-3" # Same for both + assert result["alias"] == "" # Multiple options, no auto-select + + def test_invalid_selection_cleared(self, app_server, sample_vector_stores_list): + """Test that invalid selections are cleared""" + assert app_server is not None + + vs_df = pd.DataFrame(sample_vector_stores_list) + # Set an invalid alias that doesn't exist in the dataframe + selections = {"alias": "invalid_alias", "model": "", "chunk_size": "", + "chunk_overlap": "", "distance_metric": "", "index_type": ""} + + result = vs_options._auto_select(vs_df, selections) + + # Invalid alias should be cleared + assert result["alias"] == "" + + +############################################################################# +# Test _reset_selections +############################################################################# +class TestResetSelections: + """Test _reset_selections function""" + + def test_reset_clears_all_vs_fields(self, app_server): + """Test reset clears all vector store selection fields""" + assert app_server is not None + + # Setup initial values + state.client_settings = { + "vector_search": { + "model": "openai/text-embed-3", + "chunk_size": 1000, + "chunk_overlap": 200, + "distance_metric": "cosine", + "vector_store": "vs_test", + "alias": "test_alias", + "index_type": "IVF", + "top_k": 10, + "search_type": "Similarity", + } + } + state["_vs_key_version"] = 0 + + vs_options._reset_selections() + + # Verify VS selection fields were cleared + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["model"] == "" + assert 
state.client_settings["vector_search"]["chunk_size"] == "" + assert state.client_settings["vector_search"]["chunk_overlap"] == "" + assert state.client_settings["vector_search"]["distance_metric"] == "" + assert state.client_settings["vector_search"]["index_type"] == "" + assert state.client_settings["vector_search"]["vector_store"] == "" + + # Verify other fields were NOT cleared + assert state.client_settings["vector_search"]["top_k"] == 10 + assert state.client_settings["vector_search"]["search_type"] == "Similarity" + + def test_reset_increments_key_version(self, app_server): + """Test reset increments the key version for widget reset""" + assert app_server is not None + + state.client_settings = {"vector_search": { + "alias": "", "model": "", "chunk_size": "", "chunk_overlap": "", + "distance_metric": "", "index_type": "", "vector_store": "" + }} + state["_vs_key_version"] = 5 + + vs_options._reset_selections() + + assert state["_vs_key_version"] == 6 + + +############################################################################# +# Test _get_current_selections +############################################################################# +class TestGetCurrentSelections: + """Test _get_current_selections function""" + + def test_gets_from_widget_state(self, app_server): + """Test getting selections from widget state""" + assert app_server is not None + + state.client_settings = {"vector_search": {"alias": "from_settings"}} + state["vs_alias_0"] = "from_widget" + + result = vs_options._get_current_selections(key_version=0) + + assert result["alias"] == "from_widget" + + def test_falls_back_to_client_settings(self, app_server): + """Test fallback to client_settings when widget not in state""" + assert app_server is not None + + # Use a different key_version to avoid state from previous test + state.client_settings = {"vector_search": {"alias": "from_settings", "model": "test_model"}} + # No widget state set for this key_version + + result = vs_options._get_current_selections(key_version=99) + + assert result["alias"] == "from_settings" + assert result["model"] == "test_model" + + +############################################################################# +# Test _render_selectbox +############################################################################# +class TestRenderSelectbox: + """Test _render_selectbox function""" + + def test_selectbox_disabled_when_no_options(self, app_server): + """Test selectbox is disabled when no valid options""" + assert app_server is not None + + mock_container = MagicMock() + mock_container.selectbox = MagicMock(return_value="") + base_df = pd.DataFrame(columns=["alias", "model"]) + current_selections = {"alias": ""} + + vs_options._render_selectbox( + mock_container, "Select Alias:", "alias", base_df, current_selections, key_version=0 + ) + + # Verify selectbox was called with disabled=True + mock_container.selectbox.assert_called_once() + call_kwargs = mock_container.selectbox.call_args[1] + assert call_kwargs["disabled"] is True + + def test_selectbox_enabled_with_options(self, app_server, sample_vector_stores_list): + """Test selectbox is enabled when valid options exist""" + assert app_server is not None + + mock_container = MagicMock() + mock_container.selectbox = MagicMock(return_value="vs1") + base_df = pd.DataFrame(sample_vector_stores_list) + current_selections = {"alias": ""} + + vs_options._render_selectbox( + mock_container, "Select Alias:", "alias", base_df, current_selections, key_version=0 + ) + + # Verify selectbox was called 
with disabled=False + mock_container.selectbox.assert_called_once() + call_kwargs = mock_container.selectbox.call_args[1] + assert call_kwargs["disabled"] is False + + def test_selectbox_preserves_valid_selection(self, app_server, sample_vector_stores_list): + """Test selectbox preserves current selection if valid""" + assert app_server is not None + + mock_container = MagicMock() + mock_container.selectbox = MagicMock(return_value="vs1") + base_df = pd.DataFrame(sample_vector_stores_list) + current_selections = {"alias": "vs1"} + + vs_options._render_selectbox( + mock_container, "Select Alias:", "alias", base_df, current_selections, key_version=0 + ) + + # Verify selectbox was called with correct index for "vs1" + call_kwargs = mock_container.selectbox.call_args[1] + options = call_kwargs["options"] + assert "vs1" in options + # Index should point to vs1, not empty string + assert call_kwargs["index"] == options.index("vs1") + + def test_selectbox_resets_invalid_selection(self, app_server, sample_vector_stores_list): + """Test selectbox resets to empty when current selection is invalid""" + assert app_server is not None + + mock_container = MagicMock() + mock_container.selectbox = MagicMock(return_value="") + base_df = pd.DataFrame(sample_vector_stores_list) + current_selections = {"alias": "invalid_value"} + + vs_options._render_selectbox( + mock_container, "Select Alias:", "alias", base_df, current_selections, key_version=0 + ) + + # Verify selectbox was called with index 0 (empty option) + call_kwargs = mock_container.selectbox.call_args[1] + assert call_kwargs["index"] == 0 diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py index 24abbfa5..b614a00e 100644 --- a/tests/server/unit/api/utils/test_utils_chat.py +++ b/tests/server/unit/api/utils/test_utils_chat.py @@ -49,7 +49,8 @@ def sample_client_settings(self): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_success( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_request, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, + sample_request, sample_client_settings ): """Test successful completion generation""" # Setup mocks @@ -84,7 +85,8 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_streaming( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_request, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, + sample_request, sample_client_settings ): """Test streaming completion generation""" # Setup mocks @@ -166,7 +168,8 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_no_model_specified( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, sample_message, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, + sample_message, sample_client_settings ): """Test completion generation when no model is specified in request""" # Create request without model diff --git a/tests/server/unit/api/utils/test_utils_databases_crud.py b/tests/server/unit/api/utils/test_utils_databases_crud.py index f0c9af60..6ad67d00 100644 --- a/tests/server/unit/api/utils/test_utils_databases_crud.py +++ 
b/tests/server/unit/api/utils/test_utils_databases_crud.py @@ -17,6 +17,9 @@ class TestDatabases: """Test databases module functionality""" + sample_database: Database + sample_database_2: Database + def setup_method(self): """Setup test data before each test""" self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") diff --git a/tests/server/unit/api/utils/test_utils_databases_functions.py b/tests/server/unit/api/utils/test_utils_databases_functions.py index eb2a6bf3..42736539 100644 --- a/tests/server/unit/api/utils/test_utils_databases_functions.py +++ b/tests/server/unit/api/utils/test_utils_databases_functions.py @@ -20,6 +20,8 @@ class TestDatabaseUtilsPrivateFunctions: """Test private utility functions""" + sample_database: Database + def setup_method(self): """Setup test data""" self.sample_database = Database( @@ -168,6 +170,8 @@ def test_get_vs_malformed_json(self, mock_execute_sql): class TestDatabaseUtilsPublicFunctions: """Test public utility functions - connection and execution""" + sample_database: Database + def setup_method(self): """Setup test data""" self.sample_database = Database( @@ -437,6 +441,8 @@ def test_drop_vs_calls_langchain(self, mock_drop_table): class TestDatabaseUtilsQueryFunctions: """Test public utility functions - get and client database functions""" + sample_database: Database + def setup_method(self): """Setup test data""" self.sample_database = Database( From 9b10be46af90e45a67ac8dbec5f46e639e0cd10e Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 08:46:49 +0000 Subject: [PATCH 05/20] Add description to vector stores --- src/client/content/tools/tabs/split_embed.py | 38 +++++++++----------- src/server/api/utils/embed.py | 16 ++++++--- src/server/api/v1/embed.py | 17 +++++++++ 3 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index f67085b2..ff44f0ee 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -295,7 +295,7 @@ def _display_file_list_expander(file_list_response: dict) -> None: # Build expander title total_files = file_list_response["total_files"] total_chunks = file_list_response["total_chunks"] - expander_title = f"📁 Exiting Embeddings ({total_files} files, {total_chunks} chunks)" + expander_title = f"📁 Existing Embeddings ({total_files} files, {total_chunks} chunks)" orphaned = file_list_response.get("orphaned_chunks", 0) if orphaned > 0: expander_title += f" ⚠️ {orphaned} orphaned" @@ -347,7 +347,6 @@ def _validate_new_alias(alias: str) -> bool: """Validate a new vector store alias and display appropriate messages.""" alias_pattern = r"^[A-Za-z][A-Za-z0-9_]*$" if not alias: - st.warning("Please enter a Vector Store Alias to continue.") return True if not re.match(alias_pattern, alias): st.error( @@ -403,42 +402,39 @@ def _render_populate_vs_section( for store in db.get("vector_stores", []) ) if vs_exists: - st.caption("Vector store already exists. 
New chunks will be added to existing Vector Store.")
+        try:
+            file_list_response = api_call.get(endpoint=f"v1/embed/{embed_request.vector_store}/files")
+            if file_list_response and "files" in file_list_response:
+                _display_file_list_expander(file_list_response)
+        except api_call.ApiError as e:
+            logger.warning("Could not retrieve file list for %s: %s", embed_request.vector_store, e)
     else:
-        st.caption("New vector store will be created.")
+        st.caption("A new vector store will be created.")
 
-    # Get Description
-    st.markdown("**Vector Store Description (Provide a description to help the retriever find relevant tables):**")
+    # Vector Store Description
+    st.divider()
     col1, col2 = st.columns([4, 1])
     with col1:
         embed_request.description = st.text_input(
-            "Vector Store Description:",
+            "Provide a description to help AI understand the purpose of this Vector Store:",
             max_chars=255,
             value=embed_request.description,
-            placeholder="Enter a description for the new vector store",
-            label_visibility="collapsed",
+            placeholder="Enter a description for the Vector Store.",
+            # label_visibility="collapsed",
         )
+    col2.space("small")
     with col2:
-        if not create_new_vs and embed_request.description:
+        if not create_new_vs:
             if st.button(
                 "Update Description",
                 type="secondary",
                 key="comment_update",
-                help="Update the description of an existing Vector Store.",
+                help="Update the description of the Vector Store.",
             ):
                 _ = api_call.patch(
                     endpoint="v1/embed/comment", payload={"json": embed_request.model_dump()}, toast=True
                 )
 
-    # Display files in existing vector store
-    if not create_new_vs and embed_request.vector_store:
-        try:
-            file_list_response = api_call.get(endpoint=f"v1/embed/{embed_request.vector_store}/files")
-            if file_list_response and "files" in file_list_response:
-                _display_file_list_expander(file_list_response)
-        except api_call.ApiError as e:
-            logger.warning("Could not retrieve file list for %s: %s", embed_request.vector_store, e)
-
     # Always render rate limit input to ensure session state is initialized
     rate_size, _ = st.columns([0.28, 0.72])
     rate_limit = rate_size.number_input(
diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py
index 09a78e8a..5e5ed5a3 100644
--- a/src/server/api/utils/embed.py
+++ b/src/server/api/utils/embed.py
@@ -404,15 +404,20 @@ def _merge_and_index_vector_store(
     except Exception as ex:
         logger.error("Unable to create vector index: %s", ex)
 
-    # Comment the VS table
+
+##########################################
+# Vector Store
+##########################################
+def update_vs_comment(vector_store: schema.DatabaseVectorStorage, db_details: schema.Database) -> None:
+    """Comment on Existing Vector Store"""
+    db_conn = utils_databases.connect(db_details)
+
     _, store_comment = functions.get_vs_table(**vector_store.model_dump(exclude={"database", "vector_store"}))
     comment = f"COMMENT ON TABLE {vector_store.vector_store} IS 'GENAI: {store_comment}'"
     utils_databases.execute_sql(db_conn, comment)
+    utils_databases.disconnect(db_conn)
 
 
-##########################################
-# Vector Store
-##########################################
 def populate_vs(
     vector_store: schema.DatabaseVectorStorage,
     db_details: schema.Database,
@@ -433,7 +438,8 @@ def 
populate_vs( # Merge and index _merge_and_index_vector_store(db_conn, vector_store, vector_store_tmp, embed_client) - utils_databases.disconnect(db_conn) + # Comment the VS table + update_vs_comment(vector_store, db_details) ########################################## diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index 95f16479..c754eaca 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -80,6 +80,23 @@ async def embed_get_files( raise HTTPException(status_code=400, detail=f"Could not retrieve file list: {str(ex)}") from ex +@auth.patch( + "/comment", + description="Update existing Vector Store Comment.", +) +async def comment_vs( + request: schema.DatabaseVectorStorage, + client: schema.ClientIdType = Header(default="server"), +) -> Response: + """Update the comment on an existing Vector Store""" + logger.info("Received comment_vs - request: %s", request) + utils_embed.update_vs_comment( + vector_store=request, + db_details=utils_databases.get_client_database(client), + ) + return Response(content=json.dumps({"message": "Vector Store comment updated."}), media_type="application/json") + + @auth.post( "/sql/store", description="Store SQL field for Embedding.", From 08c313fb1a3f4eabb7c6493a29810e7a73e875c3 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 12:32:58 +0000 Subject: [PATCH 06/20] Refresh tests --- src/server/api/utils/databases.py | 4 +- test/__init__.py | 1 + test/conftest.py | 161 ++++ test/unit/__init__.py | 1 + test/unit/server/__init__.py | 1 + test/unit/server/api/__init__.py | 1 + test/unit/server/api/conftest.py | 323 +++++++ test/unit/server/api/utils/__init__.py | 1 + test/unit/server/api/utils/test_utils_chat.py | 312 +++++++ .../server/api/utils/test_utils_databases.py | 502 +++++++++++ .../unit/server/api/utils/test_utils_embed.py | 805 ++++++++++++++++++ test/unit/server/api/utils/test_utils_mcp.py | 192 +++++ .../server/api/utils/test_utils_models.py | 433 ++++++++++ test/unit/server/api/utils/test_utils_oci.py | 595 +++++++++++++ .../server/api/utils/test_utils_settings.py | 352 ++++++++ .../server/api/utils/test_utils_testbed.py | 324 +++++++ .../server/api/utils/test_utils_webscrape.py | 419 +++++++++ test/unit/server/api/v1/__init__.py | 1 + test/unit/server/api/v1/test_v1_chat.py | 258 ++++++ test/unit/server/api/v1/test_v1_databases.py | 184 ++++ test/unit/server/api/v1/test_v1_embed.py | 553 ++++++++++++ test/unit/server/api/v1/test_v1_mcp.py | 169 ++++ .../unit/server/api/v1/test_v1_mcp_prompts.py | 229 +++++ test/unit/server/api/v1/test_v1_models.py | 254 ++++++ test/unit/server/api/v1/test_v1_oci.py | 362 ++++++++ test/unit/server/api/v1/test_v1_probes.py | 129 +++ test/unit/server/api/v1/test_v1_settings.py | 326 +++++++ test/unit/server/api/v1/test_v1_testbed.py | 305 +++++++ 28 files changed, 7195 insertions(+), 2 deletions(-) create mode 100644 test/__init__.py create mode 100644 test/conftest.py create mode 100644 test/unit/__init__.py create mode 100644 test/unit/server/__init__.py create mode 100644 test/unit/server/api/__init__.py create mode 100644 test/unit/server/api/conftest.py create mode 100644 test/unit/server/api/utils/__init__.py create mode 100644 test/unit/server/api/utils/test_utils_chat.py create mode 100644 test/unit/server/api/utils/test_utils_databases.py create mode 100644 test/unit/server/api/utils/test_utils_embed.py create mode 100644 test/unit/server/api/utils/test_utils_mcp.py create mode 100644 test/unit/server/api/utils/test_utils_models.py create mode 100644 
test/unit/server/api/utils/test_utils_oci.py create mode 100644 test/unit/server/api/utils/test_utils_settings.py create mode 100644 test/unit/server/api/utils/test_utils_testbed.py create mode 100644 test/unit/server/api/utils/test_utils_webscrape.py create mode 100644 test/unit/server/api/v1/__init__.py create mode 100644 test/unit/server/api/v1/test_v1_chat.py create mode 100644 test/unit/server/api/v1/test_v1_databases.py create mode 100644 test/unit/server/api/v1/test_v1_embed.py create mode 100644 test/unit/server/api/v1/test_v1_mcp.py create mode 100644 test/unit/server/api/v1/test_v1_mcp_prompts.py create mode 100644 test/unit/server/api/v1/test_v1_models.py create mode 100644 test/unit/server/api/v1/test_v1_oci.py create mode 100644 test/unit/server/api/v1/test_v1_probes.py create mode 100644 test/unit/server/api/v1/test_v1_settings.py create mode 100644 test/unit/server/api/v1/test_v1_testbed.py diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 588cd369..8ef8735a 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -266,8 +266,8 @@ def get_client_database(client: ClientIdType, validate: bool = False) -> Databas # Get database name from client settings, defaulting to "DEFAULT" db_name = "DEFAULT" - if hasattr(client_settings, "vector_search") and client_settings.vector_search: - db_name = getattr(client_settings.vector_search, "database", "DEFAULT") + if hasattr(client_settings, "database") and client_settings.database: + db_name = getattr(client_settings.database, "alias", "DEFAULT") # Return Single the Database Object return get_databases(db_name=db_name, validate=validate) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..66173aec --- /dev/null +++ b/test/__init__.py @@ -0,0 +1 @@ +# Test package diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..e409c62e --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,161 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Pytest fixtures for unit tests with real Oracle database. +Adapts the Docker container pattern from tests/conftest.py. 
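+The container is session-scoped: it is started once on host port 1525 (service
+FREEPDB1, credentials in TEST_CONFIG below) and removed when the session ends.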
+""" + +# pylint: disable=consider-using-with +# pylint: disable=redefined-outer-name +# Pytest fixtures use parameter injection where fixture names match parameters + +import time +import shutil +from pathlib import Path +from typing import Generator, Optional +from contextlib import contextmanager + +import pytest +import oracledb +import docker +from docker.errors import DockerException +from docker.models.containers import Container + + +# Test database configuration - matches tests/conftest.py +TEST_CONFIG = { + "db_username": "PYTEST", + "db_password": "OrA_41_3xPl0d3r", + "db_dsn": "//localhost:1525/FREEPDB1", +} + + +def wait_for_container_ready(container: Container, ready_output: str, since: Optional[int] = None) -> None: + """Wait for container to be ready by checking its logs with exponential backoff.""" + start_time = time.time() + retry_interval = 2 + + while time.time() - start_time < 120: # 2 minute timeout + try: + logs = container.logs(tail=100, since=since).decode("utf-8") + if ready_output in logs: + return + except DockerException as e: + container.remove(force=True) + raise DockerException(f"Failed to get container logs: {str(e)}") from e + + time.sleep(retry_interval) + retry_interval = min(retry_interval * 2, 10) # Exponential backoff, max 10 seconds + + container.remove(force=True) + raise TimeoutError("Container did not become ready within timeout") + + +@contextmanager +def temp_sql_setup(): + """Context manager for temporary SQL setup files.""" + temp_dir = Path("test/db_startup_temp") + try: + temp_dir.mkdir(exist_ok=True) + sql_content = f""" + alter system set vector_memory_size=512M scope=spfile; + + alter session set container=FREEPDB1; + CREATE TABLESPACE IF NOT EXISTS USERS DATAFILE '/opt/oracle/oradata/FREE/FREEPDB1/users_01.dbf' SIZE 100M; + CREATE USER IF NOT EXISTS "{TEST_CONFIG["db_username"]}" IDENTIFIED BY {TEST_CONFIG["db_password"]} + DEFAULT TABLESPACE "USERS" + TEMPORARY TABLESPACE "TEMP"; + GRANT "DB_DEVELOPER_ROLE" TO "{TEST_CONFIG["db_username"]}"; + ALTER USER "{TEST_CONFIG["db_username"]}" DEFAULT ROLE ALL; + ALTER USER "{TEST_CONFIG["db_username"]}" QUOTA UNLIMITED ON USERS; + + EXIT; + """ + + temp_sql_file = temp_dir / "01_db_user.sql" + temp_sql_file.write_text(sql_content, encoding="UTF-8") + yield temp_dir + finally: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + +@pytest.fixture(scope="session") +def db_container() -> Generator[Container, None, None]: + """Create and manage an Oracle database container for testing.""" + db_client = docker.from_env() + container = None + + try: + with temp_sql_setup() as temp_dir: + container = db_client.containers.run( + "container-registry.oracle.com/database/free:latest-lite", + environment={ + "ORACLE_PWD": TEST_CONFIG["db_password"], + "ORACLE_PDB": TEST_CONFIG["db_dsn"].rsplit("/", maxsplit=1)[-1], # FREEPDB1 + }, + ports={"1521/tcp": int(TEST_CONFIG["db_dsn"].split(":")[1].split("/")[0])}, # 1525 + volumes={str(temp_dir.absolute()): {"bind": "/opt/oracle/scripts/startup", "mode": "ro"}}, + detach=True, + ) + + # Wait for database to be ready + wait_for_container_ready(container, "DATABASE IS READY TO USE!") + + # Restart container to apply vector_memory_size + container.restart() + restart_time = int(time.time()) + wait_for_container_ready(container, "DATABASE IS READY TO USE!", since=restart_time) + + yield container + + except DockerException as e: + if container: + container.remove(force=True) + raise DockerException(f"Docker operation failed: {str(e)}") from e + + finally: + if 
container: + try: + container.stop(timeout=30) + container.remove() + except DockerException as e: + print(f"Warning: Failed to cleanup database container: {str(e)}") + + +@pytest.fixture(scope="session") +def db_connection(db_container) -> Generator[oracledb.Connection, None, None]: + """Session-scoped real Oracle database connection. + + Depends on db_container to ensure database is running. + Fails explicitly if connection cannot be established. + """ + # pylint: disable=unused-argument + conn = oracledb.connect( + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"], + ) + yield conn + conn.close() + + +@pytest.fixture +def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: + """Transaction isolation for each test using savepoints. + + Creates a savepoint before each test and rolls back after, + ensuring tests don't affect each other's database state. + + Note: This is NOT autouse - tests must explicitly request it + to get transaction isolation. This allows tests that don't + need database access to run without the overhead. + """ + cursor = db_connection.cursor() + cursor.execute("SAVEPOINT test_savepoint") + + yield db_connection + + cursor.execute("ROLLBACK TO SAVEPOINT test_savepoint") + cursor.close() diff --git a/test/unit/__init__.py b/test/unit/__init__.py new file mode 100644 index 00000000..06825972 --- /dev/null +++ b/test/unit/__init__.py @@ -0,0 +1 @@ +# Unit test package diff --git a/test/unit/server/__init__.py b/test/unit/server/__init__.py new file mode 100644 index 00000000..bc4d60b5 --- /dev/null +++ b/test/unit/server/__init__.py @@ -0,0 +1 @@ +# Server unit test package diff --git a/test/unit/server/api/__init__.py b/test/unit/server/api/__init__.py new file mode 100644 index 00000000..b4333d68 --- /dev/null +++ b/test/unit/server/api/__init__.py @@ -0,0 +1 @@ +# API unit test package diff --git a/test/unit/server/api/conftest.py b/test/unit/server/api/conftest.py new file mode 100644 index 00000000..c1ba1493 --- /dev/null +++ b/test/unit/server/api/conftest.py @@ -0,0 +1,323 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Pytest fixtures for server/api unit tests. +Provides factory fixtures for creating test objects. 
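+Each factory accepts keyword overrides, so tests only specify the fields they
+assert on and inherit defaults for everything else.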
+""" + +# pylint: disable=redefined-outer-name +# Pytest fixtures use parameter injection where fixture names match parameters + +from unittest.mock import MagicMock, AsyncMock +import pytest + +from common.schema import ( + Database, + DatabaseAuth, + Model, + OracleCloudSettings, + Settings, + LargeLanguageSettings, + DatabaseVectorStorage, + ChatRequest, + Configuration, +) + + +@pytest.fixture +def make_database(): + """Factory fixture to create Database objects.""" + + def _make_database( + name: str = "TEST_DB", + user: str = "test_user", + password: str = "test_password", + dsn: str = "localhost:1521/TESTPDB", + wallet_password: str = None, + **kwargs, + ) -> Database: + return Database( + name=name, + user=user, + password=password, + dsn=dsn, + wallet_password=wallet_password, + **kwargs, + ) + + return _make_database + + +@pytest.fixture +def make_model(): + """Factory fixture to create Model objects.""" + + def _make_model( + model_id: str = "gpt-4o-mini", + model_type: str = "ll", + provider: str = "openai", + enabled: bool = True, + **kwargs, + ) -> Model: + return Model( + id=model_id, + type=model_type, + provider=provider, + enabled=enabled, + **kwargs, + ) + + return _make_model + + +@pytest.fixture +def make_oci_config(): + """Factory fixture to create OracleCloudSettings objects.""" + + def _make_oci_config( + auth_profile: str = "DEFAULT", + genai_region: str = "us-ashburn-1", + **kwargs, + ) -> OracleCloudSettings: + return OracleCloudSettings( + auth_profile=auth_profile, + genai_region=genai_region, + **kwargs, + ) + + return _make_oci_config + + +@pytest.fixture +def make_ll_settings(): + """Factory fixture to create LargeLanguageSettings objects.""" + + def _make_ll_settings( + model: str = "gpt-4o-mini", + temperature: float = 0.7, + max_tokens: int = 4096, + chat_history: bool = True, + **kwargs, + ) -> LargeLanguageSettings: + return LargeLanguageSettings( + model=model, + temperature=temperature, + max_tokens=max_tokens, + chat_history=chat_history, + **kwargs, + ) + + return _make_ll_settings + + +@pytest.fixture +def make_settings(make_ll_settings): + """Factory fixture to create Settings objects.""" + + def _make_settings( + client: str = "test_client", + ll_model: LargeLanguageSettings = None, + **kwargs, + ) -> Settings: + if ll_model is None: + ll_model = make_ll_settings() + return Settings( + client=client, + ll_model=ll_model, + **kwargs, + ) + + return _make_settings + + +@pytest.fixture +def make_database_auth(): + """Factory fixture to create DatabaseAuth objects.""" + + def _make_database_auth( + user: str = "test_user", + password: str = "test_password", + dsn: str = "localhost:1521/TESTPDB", + wallet_password: str = None, + **kwargs, + ) -> DatabaseAuth: + return DatabaseAuth( + user=user, + password=password, + dsn=dsn, + wallet_password=wallet_password, + **kwargs, + ) + + return _make_database_auth + + +@pytest.fixture +def make_vector_store(): + """Factory fixture to create DatabaseVectorStorage objects.""" + + def _make_vector_store( + vector_store: str = "VS_TEST", + model: str = "text-embedding-3-small", + chunk_size: int = 1000, + chunk_overlap: int = 200, + **kwargs, + ) -> DatabaseVectorStorage: + return DatabaseVectorStorage( + vector_store=vector_store, + model=model, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs, + ) + + return _make_vector_store + + +@pytest.fixture +def make_chat_request(): + """Factory fixture to create ChatRequest objects.""" + + def _make_chat_request( + content: str = "Hello", + role: str = 
"user", + **kwargs, + ) -> ChatRequest: + return ChatRequest( + messages=[{"role": role, "content": content}], + **kwargs, + ) + + return _make_chat_request + + +@pytest.fixture +def make_mcp_prompt(): + """Factory fixture to create MCP prompt mock objects.""" + + def _make_mcp_prompt( + name: str = "optimizer_test-prompt", + description: str = "Test prompt description", + text: str = "Test prompt text content", + ): + mock_prompt = MagicMock() + mock_prompt.name = name + mock_prompt.description = description + mock_prompt.text = text + mock_prompt.model_dump.return_value = { + "name": name, + "description": description, + "text": text, + } + return mock_prompt + + return _make_mcp_prompt + + +@pytest.fixture +def make_configuration(make_settings): + """Factory fixture to create Configuration objects.""" + + def _make_configuration( + client: str = "test_client", + client_settings: Settings = None, + **kwargs, + ) -> Configuration: + if client_settings is None: + client_settings = make_settings(client=client) + return Configuration( + client_settings=client_settings, + database_configs=[], + model_configs=[], + oci_configs=[], + prompt_configs=[], + **kwargs, + ) + + return _make_configuration + + +@pytest.fixture +def mock_fastmcp(): + """Create a mock FastMCP application.""" + mock_mcp = MagicMock() + mock_mcp.list_tools = AsyncMock(return_value=[]) + mock_mcp.list_resources = AsyncMock(return_value=[]) + mock_mcp.list_prompts = AsyncMock(return_value=[]) + return mock_mcp + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCP client.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[]) + mock_client.get_prompt = AsyncMock(return_value=MagicMock()) + mock_client.close = AsyncMock() + return mock_client + + +@pytest.fixture +def mock_db_connection(): + """Create a mock database connection for endpoint tests. + + This mock is used by v1 endpoint tests that mock the underlying + database utilities. It provides a simple MagicMock that can be + passed around without needing a real database connection. + + For tests that need actual database operations, use the real + db_connection or db_transaction fixtures from test/conftest.py. + """ + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock() + mock_conn.cursor.return_value.__exit__ = MagicMock() + mock_conn.commit = MagicMock() + mock_conn.rollback = MagicMock() + mock_conn.close = MagicMock() + return mock_conn + + +@pytest.fixture +def mock_request_app_state(mock_fastmcp): + """Create a mock FastAPI request with app state.""" + mock_request = MagicMock() + mock_request.app.state.fastmcp_app = mock_fastmcp + return mock_request + + +@pytest.fixture +def mock_bootstrap(): + """Create mocks for bootstrap module dependencies.""" + return { + "databases": [], + "models": [], + "oci_configs": [], + "prompts": [], + "settings": [], + } + + +def create_mock_aiohttp_session(mock_session_class, mock_response): + """Helper to create a mock aiohttp ClientSession with response. + + This is a shared utility for tests that need to mock aiohttp.ClientSession. + It properly sets up async context manager behavior for session.get(). 
+ + Args: + mock_session_class: The patched aiohttp.ClientSession class + mock_response: The mock response object to return from session.get() + + Returns: + The configured mock session object + """ + mock_session = AsyncMock() + mock_session.get = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=mock_response)) + ) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock() + mock_session_class.return_value = mock_session + return mock_session diff --git a/test/unit/server/api/utils/__init__.py b/test/unit/server/api/utils/__init__.py new file mode 100644 index 00000000..9d9b7b29 --- /dev/null +++ b/test/unit/server/api/utils/__init__.py @@ -0,0 +1 @@ +# Utils unit test package diff --git a/test/unit/server/api/utils/test_utils_chat.py b/test/unit/server/api/utils/test_utils_chat.py new file mode 100644 index 00000000..22f952a6 --- /dev/null +++ b/test/unit/server/api/utils/test_utils_chat.py @@ -0,0 +1,312 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/chat.py +Tests for chat completion utility functions. +""" + +from unittest.mock import patch, MagicMock +import pytest + +from server.api.utils import chat as utils_chat +from server.api.utils.models import UnknownModelError +from common.schema import ChatRequest + + +class TestCompletionGenerator: + """Tests for the completion_generator function.""" + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_completions_mode( + self, + mock_graph, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should yield final response in completions mode.""" + mock_get_client.return_value = make_settings() + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "gpt-4o-mini"} + + async def mock_astream(**_kwargs): + yield {"completion": {"choices": [{"message": {"content": "Hello!"}}]}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Hi") + results = [] + async for output in utils_chat.completion_generator("test_client", request, "completions"): + results.append(output) + + assert len(results) == 1 + assert results[0]["choices"][0]["message"]["content"] == "Hello!" 
+ + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_streams_mode( + self, + mock_graph, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should yield stream chunks in streams mode.""" + mock_get_client.return_value = make_settings() + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "gpt-4o-mini"} + + async def mock_astream(**_kwargs): + yield {"stream": "Hello"} + yield {"stream": " World"} + yield {"completion": {"choices": []}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Hi") + results = [] + async for output in utils_chat.completion_generator("test_client", request, "streams"): + results.append(output) + + # Should have 3 outputs: 2 stream chunks + stream_finished + assert len(results) == 3 + assert results[0] == b"Hello" + assert results[1] == b" World" + assert results[2] == "[stream_finished]" + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.completion") + async def test_completion_generator_unknown_model_error( + self, + mock_completion, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should return error response on UnknownModelError.""" + mock_get_client.return_value = make_settings() + mock_oci_get.return_value = make_oci_config() + mock_get_config.side_effect = UnknownModelError("Model not found") + + mock_error_response = MagicMock() + mock_error_response.choices = [MagicMock()] + mock_error_response.choices[0].message.content = "I'm unable to initialise the Language Model." 
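+        # The patched completion() returns this canned response; no real model is called.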
+ mock_completion.return_value = mock_error_response + + request = make_chat_request(content="Hi") + results = [] + async for output in utils_chat.completion_generator("test_client", request, "completions"): + results.append(output) + + assert len(results) == 1 + mock_completion.assert_called_once() + # Verify mock_response was used + call_kwargs = mock_completion.call_args.kwargs + assert "mock_response" in call_kwargs + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_uses_request_model( + self, mock_graph, mock_get_config, mock_oci_get, mock_get_client, make_settings, make_oci_config + ): + """completion_generator should use model from request if provided.""" + mock_get_client.return_value = make_settings() + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "claude-3"} + + async def mock_astream(**_kwargs): + yield {"completion": {}} + + mock_graph.astream = mock_astream + + request = ChatRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3") + async for _ in utils_chat.completion_generator("test_client", request, "completions"): + pass + + # get_litellm_config should be called with the request model + call_args = mock_get_config.call_args[0] + assert call_args[0]["model"] == "claude-3" + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_uses_settings_model_when_not_in_request( + self, + mock_graph, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + make_ll_settings, + ): + """completion_generator should use model from settings when not in request.""" + settings = make_settings(ll_model=make_ll_settings(model="gpt-4-turbo")) + mock_get_client.return_value = settings + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "gpt-4-turbo"} + + async def mock_astream(**_kwargs): + yield {"completion": {}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Hi") # No model specified + async for _ in utils_chat.completion_generator("test_client", request, "completions"): + pass + + # get_litellm_config should be called with settings model + call_args = mock_get_config.call_args[0] + assert call_args[0]["model"] == "gpt-4-turbo" + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.utils_databases.get_client_database") + @patch("server.api.utils.chat.utils_models.get_client_embed") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_with_vector_search_enabled( + self, + mock_graph, + mock_get_embed, + mock_get_db, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should setup db connection when vector search enabled.""" + settings = make_settings() + settings.tools_enabled = ["Vector Search"] + mock_get_client.return_value = settings + mock_oci_get.return_value 
= make_oci_config() + mock_get_config.return_value = {"model": "gpt-4o-mini"} + + mock_db = MagicMock() + mock_db.connection = MagicMock() + mock_get_db.return_value = mock_db + mock_get_embed.return_value = MagicMock() + + async def mock_astream(**_kwargs): + yield {"completion": {}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Hi") + async for _ in utils_chat.completion_generator("test_client", request, "completions"): + pass + + mock_get_db.assert_called_once_with("test_client", False) + mock_get_embed.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_passes_correct_config( + self, + mock_graph, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should pass correct config to chatbot_graph.""" + settings = make_settings() + mock_get_client.return_value = settings + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "gpt-4o-mini"} + + captured_kwargs = {} + + async def mock_astream(**kwargs): + captured_kwargs.update(kwargs) + yield {"completion": {}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Test message") + async for _ in utils_chat.completion_generator("test_client", request, "completions"): + pass + + assert captured_kwargs["stream_mode"] == "custom" + assert captured_kwargs["config"]["configurable"]["thread_id"] == "test_client" + assert captured_kwargs["config"]["metadata"]["streaming"] is False + + @pytest.mark.asyncio + @patch("server.api.utils.chat.utils_settings.get_client") + @patch("server.api.utils.chat.utils_oci.get") + @patch("server.api.utils.chat.utils_models.get_litellm_config") + @patch("server.api.utils.chat.chatbot_graph") + async def test_completion_generator_streaming_metadata( + self, + mock_graph, + mock_get_config, + mock_oci_get, + mock_get_client, + make_settings, + make_chat_request, + make_oci_config, + ): + """completion_generator should set streaming=True for streams mode.""" + mock_get_client.return_value = make_settings() + mock_oci_get.return_value = make_oci_config() + mock_get_config.return_value = {"model": "gpt-4o-mini"} + + captured_kwargs = {} + + async def mock_astream(**kwargs): + captured_kwargs.update(kwargs) + yield {"completion": {}} + + mock_graph.astream = mock_astream + + request = make_chat_request(content="Test") + async for _ in utils_chat.completion_generator("test_client", request, "streams"): + pass + + assert captured_kwargs["config"]["metadata"]["streaming"] is True + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_chat, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_chat.logger.name == "api.utils.chat" diff --git a/test/unit/server/api/utils/test_utils_databases.py b/test/unit/server/api/utils/test_utils_databases.py new file mode 100644 index 00000000..4b67062f --- /dev/null +++ b/test/unit/server/api/utils/test_utils_databases.py @@ -0,0 +1,502 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+ +Unit tests for server/api/utils/databases.py +Tests for database utility functions. + +Uses hybrid approach: +- Real Oracle database for connection/SQL execution tests +- Mocks for pure Python logic tests (in-memory operations, exception handling) +""" + +# pylint: disable=too-few-public-methods + +from test.conftest import TEST_CONFIG +from unittest.mock import patch, MagicMock + +import pytest +import oracledb + +from common.schema import DatabaseSettings +from server.api.utils import databases as utils_databases +from server.api.utils.databases import DbException, ExistsDatabaseError, UnknownDatabaseError + + +class TestDbException: + """Tests for DbException class.""" + + def test_db_exception_init(self): + """DbException should store status_code and detail.""" + exc = DbException(status_code=404, detail="Not found") + assert exc.status_code == 404 + assert exc.detail == "Not found" + + def test_db_exception_message(self): + """DbException should use detail as message.""" + exc = DbException(status_code=500, detail="Server error") + assert str(exc) == "Server error" + + +class TestExistsDatabaseError: + """Tests for ExistsDatabaseError class.""" + + def test_exists_database_error_is_value_error(self): + """ExistsDatabaseError should inherit from ValueError.""" + exc = ExistsDatabaseError("Database exists") + assert isinstance(exc, ValueError) + + +class TestUnknownDatabaseError: + """Tests for UnknownDatabaseError class.""" + + def test_unknown_database_error_is_value_error(self): + """UnknownDatabaseError should inherit from ValueError.""" + exc = UnknownDatabaseError("Database not found") + assert isinstance(exc, ValueError) + + +class TestCreate: + """Tests for the create function.""" + + @patch("server.api.utils.databases.get") + @patch("server.api.utils.databases.DATABASE_OBJECTS", []) + def test_create_success(self, mock_get, make_database): + """create should add database to DATABASE_OBJECTS.""" + mock_get.side_effect = [UnknownDatabaseError("Not found"), [make_database()]] + database = make_database(name="NEW_DB") + + result = utils_databases.create(database) + + assert result is not None + + @patch("server.api.utils.databases.get") + def test_create_raises_exists_error(self, mock_get, make_database): + """create should raise ExistsDatabaseError if database exists.""" + mock_get.return_value = [make_database(name="EXISTING_DB")] + database = make_database(name="EXISTING_DB") + + with pytest.raises(ExistsDatabaseError): + utils_databases.create(database) + + @patch("server.api.utils.databases.get") + def test_create_raises_value_error_missing_fields(self, mock_get, make_database): + """create should raise ValueError if required fields missing.""" + mock_get.side_effect = UnknownDatabaseError("Not found") + database = make_database(user=None) + + with pytest.raises(ValueError) as exc_info: + utils_databases.create(database) + + assert "user" in str(exc_info.value) + + +class TestGet: + """Tests for the get function.""" + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_all_databases(self, mock_objects, make_database): + """get should return all databases when no name provided.""" + mock_objects.__iter__ = lambda _: iter([make_database(name="DB1"), make_database(name="DB2")]) + mock_objects.__len__ = lambda _: 2 + + result = utils_databases.get() + + assert len(result) == 2 + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_specific_database(self, mock_objects, make_database): + """get should return specific database when name 
provided.""" + db1 = make_database(name="DB1") + db2 = make_database(name="DB2") + mock_objects.__iter__ = lambda _: iter([db1, db2]) + mock_objects.__len__ = lambda _: 2 + + result = utils_databases.get(name="DB1") + + assert len(result) == 1 + assert result[0].name == "DB1" + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_raises_unknown_error(self, mock_objects): + """get should raise UnknownDatabaseError if name not found.""" + mock_objects.__iter__ = lambda _: iter([]) + mock_objects.__len__ = lambda _: 0 + + with pytest.raises(UnknownDatabaseError): + utils_databases.get(name="NONEXISTENT") + + +class TestDelete: + """Tests for the delete function.""" + + def test_delete_removes_database(self, make_database): + """delete should remove database from DATABASE_OBJECTS.""" + db1 = make_database(name="DB1") + db2 = make_database(name="DB2") + + with patch("server.api.utils.databases.DATABASE_OBJECTS", [db1, db2]) as mock_objects: + utils_databases.delete("DB1") + assert len(mock_objects) == 1 + assert mock_objects[0].name == "DB2" + + +class TestConnect: + """Tests for the connect function. + + Uses real database for success case, mocks for error code testing + (since we can't easily trigger specific Oracle errors). + """ + + def test_connect_success_real_db(self, db_container, make_database): + """connect should return connection on success (real database).""" + # pylint: disable=unused-argument + config = make_database( + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"], + ) + + result = utils_databases.connect(config) + + assert result is not None + assert result.is_healthy() + result.close() + + def test_connect_raises_value_error_missing_details(self, make_database): + """connect should raise ValueError if connection details missing.""" + config = make_database(user=None, password=None, dsn=None) + + with pytest.raises(ValueError) as exc_info: + utils_databases.connect(config) + + assert "missing connection details" in str(exc_info.value) + + def test_connect_raises_permission_error_invalid_credentials(self, db_container, make_database): + """connect should raise PermissionError on invalid credentials (real database).""" + # pylint: disable=unused-argument + config = make_database( + user="INVALID_USER", + password="wrong_password", + dsn=TEST_CONFIG["db_dsn"], + ) + + with pytest.raises(PermissionError): + utils_databases.connect(config) + + def test_connect_raises_connection_error_invalid_dsn(self, db_container, make_database): + """connect should raise ConnectionError on invalid service name (real database). + + Note: DPY-6005 (cannot connect) wraps DPY-6001 (service not registered), + and the current implementation maps DPY-6005 to ConnectionError. 
+ """ + # pylint: disable=unused-argument + config = make_database( + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn="//localhost:1525/NONEXISTENT_SERVICE", + ) + + with pytest.raises(ConnectionError): + utils_databases.connect(config) + + @patch("server.api.utils.databases.oracledb.connect") + def test_connect_raises_connection_error_on_oserror(self, mock_connect, make_database): + """connect should raise ConnectionError on OSError (mocked - can't easily trigger).""" + mock_connect.side_effect = OSError("Network unreachable") + config = make_database() + + with pytest.raises(ConnectionError): + utils_databases.connect(config) + + @patch("server.api.utils.databases.oracledb.connect") + def test_connect_wallet_location_defaults_to_config_dir(self, mock_connect, make_database): + """connect should default wallet_location to config_dir if not set (mocked - verifies call args).""" + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + config = make_database(wallet_password="secret", config_dir="/path/to/config") + + utils_databases.connect(config) + + call_kwargs = mock_connect.call_args.kwargs + assert call_kwargs.get("wallet_location") == "/path/to/config" + + +class TestDisconnect: + """Tests for the disconnect function.""" + + def test_disconnect_closes_connection(self): + """disconnect should call close on connection.""" + mock_conn = MagicMock() + + utils_databases.disconnect(mock_conn) + + mock_conn.close.assert_called_once() + + +class TestExecuteSql: + """Tests for the execute_sql function. + + Uses real database for actual SQL execution tests. + """ + + def test_execute_sql_returns_rows(self, db_transaction): + """execute_sql should return query results (real database).""" + result = utils_databases.execute_sql(db_transaction, "SELECT 'val1' AS col1, 'val2' AS col2 FROM dual") + + assert len(result) == 1 + assert result[0] == ("val1", "val2") + + def test_execute_sql_with_binds(self, db_transaction): + """execute_sql should pass binds to cursor (real database).""" + result = utils_databases.execute_sql( + db_transaction, "SELECT :val AS result FROM dual", {"val": "test_value"} + ) + + assert result[0] == ("test_value",) + + def test_execute_sql_handles_clob_columns(self, db_transaction): + """execute_sql should read CLOB column values (real database).""" + # Create a CLOB using TO_CLOB function + result = utils_databases.execute_sql( + db_transaction, "SELECT TO_CLOB('CLOB content here') AS clob_col FROM dual" + ) + + # Result should have the CLOB content read as string + assert len(result) == 1 + assert "CLOB content here" in str(result[0]) + + def test_execute_sql_returns_dbms_output(self, db_transaction): + """execute_sql should return DBMS_OUTPUT when no rows (real database).""" + result = utils_databases.execute_sql( + db_transaction, + """ + BEGIN + DBMS_OUTPUT.ENABLE; + DBMS_OUTPUT.PUT_LINE('Test DBMS Output'); + END; + """, + ) + + assert "Test DBMS Output" in str(result) + + def test_execute_sql_multiple_rows(self, db_transaction): + """execute_sql should handle multiple rows (real database).""" + result = utils_databases.execute_sql( + db_transaction, + """ + SELECT LEVEL AS num FROM dual CONNECT BY LEVEL <= 3 + """, + ) + + assert len(result) == 3 + assert result[0] == (1,) + assert result[1] == (2,) + assert result[2] == (3,) + + +class TestDropVs: + """Tests for the drop_vs function.""" + + @patch("server.api.utils.databases.LangchainVS.drop_table_purge") + def test_drop_vs_calls_langchain(self, mock_drop): + """drop_vs should 
call LangchainVS.drop_table_purge.""" + mock_conn = MagicMock() + + utils_databases.drop_vs(mock_conn, "VS_TEST") + + mock_drop.assert_called_once_with(mock_conn, "VS_TEST") + + +class TestGetDatabases: + """Tests for the get_databases function.""" + + @patch("server.api.utils.databases.get") + def test_get_databases_without_name(self, mock_get, make_database): + """get_databases should return all databases without name.""" + mock_get.return_value = [make_database(name="DB1"), make_database(name="DB2")] + + result = utils_databases.get_databases() + + assert len(result) == 2 + + @patch("server.api.utils.databases.get") + def test_get_databases_with_name(self, mock_get, make_database): + """get_databases should return single database with name.""" + mock_get.return_value = [make_database(name="DB1")] + + result = utils_databases.get_databases(db_name="DB1") + + assert result.name == "DB1" + + @patch("server.api.utils.databases.get") + @patch("server.api.utils.databases.connect") + @patch("server.api.utils.databases._get_vs") + def test_get_databases_with_validate(self, mock_get_vs, mock_connect, mock_get, make_database): + """get_databases should validate connections when validate=True.""" + db = make_database(name="DB1") + mock_get.return_value = [db] + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_get_vs.return_value = [] + + result = utils_databases.get_databases(validate=True) + + mock_connect.assert_called_once() + assert result[0].connected is True + + @patch("server.api.utils.databases.get") + @patch("server.api.utils.databases.connect") + def test_get_databases_validate_handles_connection_error(self, mock_connect, mock_get, make_database): + """get_databases should continue on connection error during validation.""" + db = make_database(name="DB1") + mock_get.return_value = [db] + mock_connect.side_effect = ConnectionError("Cannot connect") + + result = utils_databases.get_databases(validate=True) + + assert len(result) == 1 + # Should not crash, just continue + + +class TestGetClientDatabase: + """Tests for the get_client_database function.""" + + @patch("server.api.utils.databases.utils_settings.get_client") + @patch("server.api.utils.databases.get_databases") + def test_get_client_database_default(self, mock_get_databases, mock_get_client, make_settings, make_database): + """get_client_database should default to DEFAULT database.""" + mock_get_client.return_value = make_settings() + mock_get_databases.return_value = make_database(name="DEFAULT") + + utils_databases.get_client_database("test_client") + + mock_get_databases.assert_called_once_with(db_name="DEFAULT", validate=False) + + @patch("server.api.utils.databases.utils_settings.get_client") + @patch("server.api.utils.databases.get_databases") + def test_get_client_database_from_database_settings( + self, mock_get_databases, mock_get_client, make_settings, make_database + ): + """get_client_database should use database alias from Settings.database.""" + settings = make_settings() + settings.database = DatabaseSettings(alias="CUSTOM_DB") + mock_get_client.return_value = settings + mock_get_databases.return_value = make_database(name="CUSTOM_DB") + + utils_databases.get_client_database("test_client") + + # Should use the alias from Settings.database + mock_get_databases.assert_called_once_with(db_name="CUSTOM_DB", validate=False) + + @patch("server.api.utils.databases.utils_settings.get_client") + @patch("server.api.utils.databases.get_databases") + def test_get_client_database_with_validate( + self, 
mock_get_databases, mock_get_client, make_settings, make_database + ): + """get_client_database should pass validate flag.""" + mock_get_client.return_value = make_settings() + mock_get_databases.return_value = make_database() + + utils_databases.get_client_database("test_client", validate=True) + + mock_get_databases.assert_called_once_with(db_name="DEFAULT", validate=True) + + +class TestTestConnection: # pylint: disable=protected-access + """Tests for the _test function.""" + + def test_test_connection_active(self, make_database): + """_test should set connected=True when ping succeeds.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.return_value = None + config.set_connection(mock_conn) + + utils_databases._test(config) + + assert config.connected is True + + @patch("server.api.utils.databases.connect") + def test_test_connection_refreshes_on_database_error(self, mock_connect, make_database): + """_test should refresh connection on DatabaseError.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.side_effect = oracledb.DatabaseError("Connection lost") + config.set_connection(mock_conn) + mock_connect.return_value = MagicMock() + + utils_databases._test(config) + + mock_connect.assert_called_once_with(config) + + def test_test_raises_db_exception_on_value_error(self, make_database): + """_test should raise DbException on ValueError.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.side_effect = ValueError("Invalid config") + config.set_connection(mock_conn) + + with pytest.raises(DbException) as exc_info: + utils_databases._test(config) + + assert exc_info.value.status_code == 400 + + def test_test_raises_db_exception_on_permission_error(self, make_database): + """_test should raise DbException on PermissionError.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.side_effect = PermissionError("Access denied") + config.set_connection(mock_conn) + + with pytest.raises(DbException) as exc_info: + utils_databases._test(config) + + assert exc_info.value.status_code == 401 + + def test_test_raises_db_exception_on_connection_error(self, make_database): + """_test should raise DbException on ConnectionError.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.side_effect = ConnectionError("Network error") + config.set_connection(mock_conn) + + with pytest.raises(DbException) as exc_info: + utils_databases._test(config) + + assert exc_info.value.status_code == 503 + + +class TestGetVs: # pylint: disable=protected-access + """Tests for the _get_vs function. + + Uses real database - queries user_tables for vector store metadata. + Note: Results depend on actual tables in test database schema. 
+ """ + + def test_get_vs_returns_list(self, db_transaction): + """_get_vs should return a list (real database).""" + result = utils_databases._get_vs(db_transaction) + + # Should return a list (may be empty if no vector stores exist) + assert isinstance(result, list) + + def test_get_vs_empty_for_clean_schema(self, db_transaction): + """_get_vs should return empty list when no vector stores (real database).""" + # In a clean test schema, there should be no vector stores + result = utils_databases._get_vs(db_transaction) + + # Either empty or returns actual vector stores if they exist + assert isinstance(result, list) + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_databases, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_databases.logger.name == "api.utils.database" diff --git a/test/unit/server/api/utils/test_utils_embed.py b/test/unit/server/api/utils/test_utils_embed.py new file mode 100644 index 00000000..dbaf2c4d --- /dev/null +++ b/test/unit/server/api/utils/test_utils_embed.py @@ -0,0 +1,805 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/embed.py +Tests for document embedding and vector store utility functions. + +Uses hybrid approach: +- Real Oracle database for vector store query tests +- Mocks for file processing logic (document loaders, splitting, etc.) +""" + +# pylint: disable=too-few-public-methods + +import json +import os +from unittest.mock import patch, MagicMock +import pytest + +from langchain_core.documents import Document as LangchainDocument + +from server.api.utils import embed as utils_embed + + +class TestUpdateVsComment: + """Tests for the update_vs_comment function.""" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.functions.get_vs_table") + @patch("server.api.utils.embed.utils_databases.execute_sql") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_update_vs_comment_success( + self, mock_disconnect, mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store + ): + """update_vs_comment should execute comment SQL.""" + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_get_vs_table.return_value = ("VS_TEST", '{"alias": "test"}') + + db_details = make_database() + vector_store = make_vector_store(vector_store="VS_TEST") + + utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) + + mock_connect.assert_called_once_with(db_details) + mock_execute_sql.assert_called_once() + mock_disconnect.assert_called_once_with(mock_conn) + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.functions.get_vs_table") + @patch("server.api.utils.embed.utils_databases.execute_sql") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_update_vs_comment_builds_correct_sql( + self, _mock_disconnect, mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store + ): + """update_vs_comment should build correct COMMENT ON TABLE SQL.""" + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_get_vs_table.return_value = ("VS_MY_STORE", '{"alias": "my_alias", "model": "embed-3"}') + + db_details = make_database() + vector_store = 
make_vector_store(vector_store="VS_MY_STORE") + + utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) + + call_args = mock_execute_sql.call_args[0] + sql = call_args[1] + assert "COMMENT ON TABLE VS_MY_STORE IS" in sql + assert "GENAI:" in sql + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.functions.get_vs_table") + @patch("server.api.utils.embed.utils_databases.execute_sql") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_update_vs_comment_disconnects_on_success( + self, mock_disconnect, _mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store + ): + """update_vs_comment should disconnect from database after execution.""" + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_get_vs_table.return_value = ("VS_TEST", "{}") + + db_details = make_database() + vector_store = make_vector_store() + + utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) + + mock_disconnect.assert_called_once_with(mock_conn) + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.functions.get_vs_table") + @patch("server.api.utils.embed.utils_databases.execute_sql") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_update_vs_comment_calls_get_vs_table_with_correct_params( + self, _mock_disconnect, _mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store + ): + """update_vs_comment should call get_vs_table excluding database and vector_store.""" + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_get_vs_table.return_value = ("VS_TEST", "{}") + + db_details = make_database() + vector_store = make_vector_store( + vector_store="VS_TEST", + model="embed-model", + chunk_size=500, + chunk_overlap=100, + ) + + utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) + + mock_get_vs_table.assert_called_once() + call_kwargs = mock_get_vs_table.call_args.kwargs + # Should NOT include database or vector_store + assert "database" not in call_kwargs + assert "vector_store" not in call_kwargs + # Should include other fields + assert "model" in call_kwargs or "chunk_size" in call_kwargs + + +class TestGetTempDirectory: + """Tests for the get_temp_directory function.""" + + @patch("server.api.utils.embed.Path") + def test_get_temp_directory_uses_app_tmp(self, mock_path): + """Should use /app/tmp if it exists.""" + mock_app_path = MagicMock() + mock_app_path.exists.return_value = True + mock_app_path.is_dir.return_value = True + mock_path.return_value = mock_app_path + mock_path.side_effect = lambda x: mock_app_path if x == "/app/tmp" else MagicMock() + + result = utils_embed.get_temp_directory("test_client", "embed") + + assert result is not None + + @patch("server.api.utils.embed.Path") + def test_get_temp_directory_uses_tmp_fallback(self, mock_path): + """Should use /tmp if /app/tmp doesn't exist.""" + mock_app_path = MagicMock() + mock_app_path.exists.return_value = False + mock_path.return_value = mock_app_path + + result = utils_embed.get_temp_directory("test_client", "embed") + + assert result is not None + + +class TestDocToJson: + """Tests for the doc_to_json function.""" + + def test_doc_to_json_creates_file(self, tmp_path): + """Should create JSON file from documents.""" + docs = [LangchainDocument(page_content="Test content", metadata={"source": "test.pdf"})] + + result = utils_embed.doc_to_json(docs, "test.pdf", 
str(tmp_path)) + + assert os.path.exists(result) + assert result.endswith(".json") + + +class TestProcessMetadata: + """Tests for the process_metadata function.""" + + def test_process_metadata_adds_metadata(self): + """Should add metadata to chunk.""" + chunk = LangchainDocument(page_content="Test content", metadata={"source": "/path/to/test.pdf", "page": 1}) + + result = utils_embed.process_metadata(1, chunk) + + assert len(result) == 1 + assert result[0].metadata["id"] == "test_1" + assert result[0].metadata["filename"] == "test.pdf" + + def test_process_metadata_includes_file_metadata(self): + """Should include file metadata if provided.""" + chunk = LangchainDocument(page_content="Test content", metadata={"source": "/path/to/doc.pdf"}) + file_metadata = {"doc.pdf": {"size": 1000, "time_modified": "2024-01-01", "etag": "abc123"}} + + result = utils_embed.process_metadata(1, chunk, file_metadata) + + assert result[0].metadata["size"] == 1000 + assert result[0].metadata["etag"] == "abc123" + + +class TestSplitDocument: + """Tests for the split_document function.""" + + def test_split_document_pdf(self): + """Should split PDF documents.""" + docs = [LangchainDocument(page_content="A" * 2000, metadata={"source": "test.pdf"})] + + result = utils_embed.split_document("default", 500, 50, docs, "pdf") + + assert len(result) > 0 + + def test_split_document_unsupported_extension(self): + """Should raise ValueError for unsupported extension.""" + docs = [LangchainDocument(page_content="Test", metadata={})] + + with pytest.raises(ValueError) as exc_info: + utils_embed.split_document("default", 500, 50, docs, "xyz") + + assert "Unsupported file type" in str(exc_info.value) + + +class TestGetDocumentLoader: # pylint: disable=protected-access + """Tests for the _get_document_loader function.""" + + def test_get_document_loader_pdf(self, tmp_path): + """Should return PyPDFLoader for PDF files.""" + test_file = tmp_path / "test.pdf" + test_file.touch() + + _, split = utils_embed._get_document_loader(str(test_file), "pdf") + + assert split is True + + def test_get_document_loader_html(self, tmp_path): + """Should return TextLoader for HTML files.""" + test_file = tmp_path / "test.html" + test_file.touch() + + _, split = utils_embed._get_document_loader(str(test_file), "html") + + assert split is True + + def test_get_document_loader_unsupported(self, tmp_path): + """Should raise ValueError for unsupported extension.""" + test_file = tmp_path / "test.xyz" + test_file.touch() + + with pytest.raises(ValueError): + utils_embed._get_document_loader(str(test_file), "xyz") + + +class TestCaptureFileMetadata: # pylint: disable=protected-access + """Tests for the _capture_file_metadata function.""" + + def test_capture_file_metadata_new_file(self, tmp_path): + """Should capture metadata for new files.""" + test_file = tmp_path / "test.txt" + test_file.write_text("content") + stat = test_file.stat() + file_metadata = {} + + utils_embed._capture_file_metadata("test.txt", stat, file_metadata) + + assert "test.txt" in file_metadata + assert "size" in file_metadata["test.txt"] + assert "time_modified" in file_metadata["test.txt"] + + def test_capture_file_metadata_existing_file(self, tmp_path): + """Should not overwrite existing metadata.""" + test_file = tmp_path / "test.txt" + test_file.write_text("content") + stat = test_file.stat() + file_metadata = {"test.txt": {"size": 9999}} + + utils_embed._capture_file_metadata("test.txt", stat, file_metadata) + + assert file_metadata["test.txt"]["size"] == 9999 # Not 
overwritten + + +class TestPrepareDocuments: # pylint: disable=protected-access + """Tests for the _prepare_documents function.""" + + def test_prepare_documents_removes_duplicates(self): + """Should remove duplicate documents.""" + docs = [ + LangchainDocument(page_content="Same content", metadata={}), + LangchainDocument(page_content="Same content", metadata={}), + LangchainDocument(page_content="Different content", metadata={}), + ] + + result = utils_embed._prepare_documents(docs) + + assert len(result) == 2 + + +class TestGetVectorStoreByAlias: + """Tests for the get_vector_store_by_alias function.""" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_vector_store_by_alias_success(self, _mock_disconnect, mock_connect, make_database): + """Should return vector store config for matching alias.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + ("VS_TEST", '{"alias": "test_alias", "model": "embed-3", "chunk_size": 500, "chunk_overlap": 100}') + ] + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_vector_store_by_alias(make_database(), "test_alias") + + assert result.vector_store == "VS_TEST" + assert result.alias == "test_alias" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_vector_store_by_alias_not_found(self, _mock_disconnect, mock_connect, make_database): + """Should raise ValueError if alias not found.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [] + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + with pytest.raises(ValueError) as exc_info: + utils_embed.get_vector_store_by_alias(make_database(), "nonexistent") + + assert "not found" in str(exc_info.value) + + +class TestGetTotalChunksCount: + """Tests for the get_total_chunks_count function.""" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_total_chunks_count_success(self, _mock_disconnect, mock_connect, make_database): + """Should return chunk count.""" + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = (150,) + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_total_chunks_count(make_database(), "VS_TEST") + + assert result == 150 + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_total_chunks_count_error(self, _mock_disconnect, mock_connect, make_database): + """Should return 0 on error.""" + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = Exception("Query failed") + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_total_chunks_count(make_database(), "VS_TEST") + + assert result == 0 + + +class TestGetProcessedObjectsMetadata: + """Tests for the get_processed_objects_metadata function.""" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_processed_objects_metadata_new_format(self, _mock_disconnect, mock_connect, make_database): + """Should return metadata in new format.""" + mock_cursor = 
MagicMock() + mock_cursor.fetchall.return_value = [({"filename": "doc.pdf", "etag": "abc", "time_modified": "2024-01-01"},)] + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_processed_objects_metadata(make_database(), "VS_TEST") + + assert "doc.pdf" in result + assert result["doc.pdf"]["etag"] == "abc" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_processed_objects_metadata_old_format(self, _mock_disconnect, mock_connect, make_database): + """Should handle old format with source field.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [({"source": "/path/to/doc.pdf"},)] + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_processed_objects_metadata(make_database(), "VS_TEST") + + assert "doc.pdf" in result + + +class TestGetVectorStoreFiles: + """Tests for the get_vector_store_files function.""" + + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed.utils_databases.disconnect") + def test_get_vector_store_files_success(self, _mock_disconnect, mock_connect, make_database): + """Should return file list with statistics.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + ({"filename": "doc1.pdf", "size": 1000},), + ({"filename": "doc1.pdf", "size": 1000},), + ({"filename": "doc2.pdf", "size": 2000},), + ] + mock_conn = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + result = utils_embed.get_vector_store_files(make_database(), "VS_TEST") + + assert result["total_files"] == 2 + assert result["total_chunks"] == 3 + + +class TestRefreshVectorStoreFromBucket: + """Tests for the refresh_vector_store_from_bucket function.""" + + @patch("server.api.utils.embed.get_temp_directory") + def test_refresh_vector_store_empty_objects( + self, _mock_get_temp, make_vector_store, make_database, make_oci_config + ): + """Should return early if no objects to process.""" + result = utils_embed.refresh_vector_store_from_bucket( + make_vector_store(), + "test-bucket", + [], + make_database(), + MagicMock(), + make_oci_config(), + ) + + assert result["processed_files"] == 0 + assert "No new or modified files" in result["message"] + + @patch("server.api.utils.embed.shutil.rmtree") + @patch("server.api.utils.embed.populate_vs") + @patch("server.api.utils.embed.load_and_split_documents") + @patch("server.api.utils.embed.utils_oci.get_object") + @patch("server.api.utils.embed.get_temp_directory") + def test_refresh_vector_store_success( + self, + mock_get_temp, + mock_get_object, + mock_load_split, + mock_populate, + _mock_rmtree, + make_vector_store, + make_database, + make_oci_config, + tmp_path, + ): + """Should process objects and populate vector store.""" + mock_get_temp.return_value = tmp_path + mock_get_object.return_value = str(tmp_path / "doc.pdf") + mock_load_split.return_value = ([LangchainDocument(page_content="test", metadata={})], []) + + bucket_objects = [{"name": "doc.pdf", "size": 1000, "time_modified": "2024-01-01", "etag": "abc"}] + + result = utils_embed.refresh_vector_store_from_bucket( + make_vector_store(), + "test-bucket", + bucket_objects, + make_database(), + MagicMock(), + make_oci_config(), + ) + + assert result["processed_files"] == 1 + mock_populate.assert_called_once() + + 
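# Refresh flow, inferred from the mocked call targets (a sketch, not the
+    # verbatim implementation): each bucket object is downloaded into a temp
+    # directory via utils_oci.get_object, run through load_and_split_documents,
+    # and handed to populate_vs; the directory is then removed with
+    # shutil.rmtree, and per-object download errors are collected rather than
+    # aborting the run, as the failure test below exercises.
+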
@patch("server.api.utils.embed.shutil.rmtree") + @patch("server.api.utils.embed.utils_oci.get_object") + @patch("server.api.utils.embed.get_temp_directory") + def test_refresh_vector_store_download_failure( + self, mock_get_temp, mock_get_object, _mock_rmtree, make_vector_store, make_database, make_oci_config, tmp_path + ): + """Should handle download failures gracefully.""" + mock_get_temp.return_value = tmp_path + mock_get_object.side_effect = Exception("Download failed") + + bucket_objects = [{"name": "doc.pdf", "size": 1000}] + + result = utils_embed.refresh_vector_store_from_bucket( + make_vector_store(), + "test-bucket", + bucket_objects, + make_database(), + MagicMock(), + make_oci_config(), + ) + + assert result["processed_files"] == 0 + assert "errors" in result + + +class TestLoadAndSplitDocuments: + """Tests for the load_and_split_documents function.""" + + @patch("server.api.utils.embed._get_document_loader") + @patch("server.api.utils.embed._process_and_split_document") + def test_load_and_split_documents_success(self, mock_process, mock_get_loader, tmp_path): + """Should load and split documents.""" + test_file = tmp_path / "test.txt" + test_file.write_text("Test content") + + mock_loader = MagicMock() + mock_loader.load.return_value = [LangchainDocument(page_content="Test", metadata={})] + mock_get_loader.return_value = (mock_loader, True) + mock_process.return_value = [LangchainDocument(page_content="Test", metadata={"id": "1"})] + + result, _ = utils_embed.load_and_split_documents([str(test_file)], "default", 500, 50) + + assert len(result) == 1 + + @patch("server.api.utils.embed._get_document_loader") + @patch("server.api.utils.embed._process_and_split_document") + @patch("server.api.utils.embed.doc_to_json") + def test_load_and_split_documents_with_json_output( + self, mock_doc_to_json, mock_process, mock_get_loader, tmp_path + ): + """Should write JSON when output_dir provided.""" + test_file = tmp_path / "test.txt" + test_file.write_text("Test content") + + mock_loader = MagicMock() + mock_loader.load.return_value = [LangchainDocument(page_content="Test", metadata={})] + mock_get_loader.return_value = (mock_loader, True) + mock_process.return_value = [LangchainDocument(page_content="Test", metadata={})] + mock_doc_to_json.return_value = str(tmp_path / "_test.json") + + _, split_files = utils_embed.load_and_split_documents( + [str(test_file)], "default", 500, 50, write_json=True, output_dir=str(tmp_path) + ) + + mock_doc_to_json.assert_called_once() + assert len(split_files) == 1 + + +class TestLoadAndSplitUrl: + """Tests for the load_and_split_url function.""" + + @patch("server.api.utils.embed.WebBaseLoader") + @patch("server.api.utils.embed.split_document") + def test_load_and_split_url_success(self, mock_split, mock_loader_class): + """Should load and split URL content.""" + mock_loader = MagicMock() + mock_loader.load.return_value = [ + LangchainDocument(page_content="Web content", metadata={"source": "http://example.com"}) + ] + mock_loader_class.return_value = mock_loader + mock_split.return_value = [LangchainDocument(page_content="Chunk", metadata={"source": "http://example.com"})] + + result, _ = utils_embed.load_and_split_url("default", "http://example.com", 500, 50) + + assert len(result) == 1 + + @patch("server.api.utils.embed.WebBaseLoader") + @patch("server.api.utils.embed.split_document") + def test_load_and_split_url_empty_content(self, mock_split, mock_loader_class): + """Should raise ValueError for empty content.""" + mock_loader = MagicMock() + 
mock_loader.load.return_value = [LangchainDocument(page_content="", metadata={})] + mock_loader_class.return_value = mock_loader + mock_split.return_value = [] + + with pytest.raises(ValueError) as exc_info: + utils_embed.load_and_split_url("default", "http://example.com", 500, 50) + + assert "no chunk-able data" in str(exc_info.value) + + +class TestJsonToDoc: # pylint: disable=protected-access + """Tests for the _json_to_doc function.""" + + def test_json_to_doc_success(self, tmp_path): + """Should convert JSON file to documents.""" + json_content = [ + {"kwargs": {"page_content": "Content 1", "metadata": {"source": "test.pdf"}}}, + {"kwargs": {"page_content": "Content 2", "metadata": {"source": "test.pdf"}}}, + ] + json_file = tmp_path / "test.json" + json_file.write_text(json.dumps(json_content)) + + result = utils_embed._json_to_doc(str(json_file)) + + assert len(result) == 2 + assert result[0].page_content == "Content 1" + + +class TestProcessAndSplitDocument: # pylint: disable=protected-access + """Tests for the _process_and_split_document function.""" + + @patch("server.api.utils.embed.split_document") + @patch("server.api.utils.embed.process_metadata") + def test_process_and_split_document_with_split(self, mock_process_meta, mock_split): + """Should split and process document.""" + mock_split.return_value = [LangchainDocument(page_content="Chunk", metadata={"source": "test.pdf"})] + mock_process_meta.return_value = [LangchainDocument(page_content="Chunk", metadata={"id": "1"})] + + loaded_doc = [LangchainDocument(page_content="Full content", metadata={})] + + result = utils_embed._process_and_split_document( + loaded_doc, + split=True, + model="default", + chunk_size=500, + chunk_overlap=50, + extension="pdf", + file_metadata={}, + ) + + mock_split.assert_called_once() + assert len(result) == 1 + + def test_process_and_split_document_no_split(self): + """Should return loaded doc without splitting.""" + loaded_doc = [LangchainDocument(page_content="Content", metadata={})] + + result = utils_embed._process_and_split_document( + loaded_doc, + split=False, + model="default", + chunk_size=500, + chunk_overlap=50, + extension="png", + file_metadata={}, + ) + + assert result == loaded_doc + + +class TestCreateTempVectorStore: # pylint: disable=protected-access + """Tests for the _create_temp_vector_store function.""" + + @patch("server.api.utils.embed.utils_databases.drop_vs") + @patch("server.api.utils.embed.OracleVS") + def test_create_temp_vector_store_success(self, mock_oracle_vs, mock_drop_vs, make_vector_store): + """Should create temporary vector store.""" + mock_vs = MagicMock() + mock_oracle_vs.return_value = mock_vs + mock_conn = MagicMock() + mock_embed_client = MagicMock() + vector_store = make_vector_store(vector_store="VS_TEST") + + _, vs_config_tmp = utils_embed._create_temp_vector_store(mock_conn, vector_store, mock_embed_client) + + assert vs_config_tmp.vector_store == "VS_TEST_TMP" + mock_drop_vs.assert_called_once() + + +class TestEmbedDocumentsInBatches: # pylint: disable=protected-access + """Tests for the _embed_documents_in_batches function.""" + + @patch("server.api.utils.embed.OracleVS.add_documents") + def test_embed_documents_in_batches_no_rate_limit(self, mock_add_docs): + """Should embed documents without rate limiting.""" + mock_vs = MagicMock() + chunks = [LangchainDocument(page_content=f"Chunk {i}", metadata={}) for i in range(10)] + + utils_embed._embed_documents_in_batches(mock_vs, chunks, rate_limit=0) + + mock_add_docs.assert_called_once() + + 
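# Rate-limited batching sketch (assumes a batch size of 500, matching the
+    # assertions in the test below; the sleep-interval derivation is not shown
+    # here and the names are illustrative, not the real implementation):
+    #   for i in range(0, len(chunks), 500):
+    #       vs.add_documents(chunks[i:i + 500])
+    #       if rate_limit:
+    #           time.sleep(delay)  # delay derived from rate_limit (assumed)
+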
@patch("server.api.utils.embed.time.sleep") + @patch("server.api.utils.embed.OracleVS.add_documents") + def test_embed_documents_in_batches_with_rate_limit(self, mock_add_docs, mock_sleep): + """Should apply rate limiting between batches.""" + mock_vs = MagicMock() + # Create 600 chunks to trigger multiple batches (batch_size=500) + chunks = [LangchainDocument(page_content=f"Chunk {i}", metadata={}) for i in range(600)] + + utils_embed._embed_documents_in_batches(mock_vs, chunks, rate_limit=60) + + assert mock_add_docs.call_count == 2 # 500 + 100 + mock_sleep.assert_called() # Rate limiting applied + + +class TestMergeAndIndexVectorStore: # pylint: disable=protected-access + """Tests for the _merge_and_index_vector_store function.""" + + @patch("server.api.utils.embed.LangchainVS.create_index") + @patch("server.api.utils.embed.utils_databases.drop_vs") + @patch("server.api.utils.embed.utils_databases.execute_sql") + @patch("server.api.utils.embed.LangchainVS.drop_index_if_exists") + @patch("server.api.utils.embed.OracleVS") + def test_merge_and_index_vector_store_hnsw( + self, _mock_oracle_vs, mock_drop_idx, mock_execute, mock_drop_vs, mock_create_idx, make_vector_store + ): + """Should merge temp store and create HNSW index.""" + mock_conn = MagicMock() + vector_store = make_vector_store(vector_store="VS_TEST", index_type="HNSW") + vector_store_tmp = make_vector_store(vector_store="VS_TEST_TMP") + + utils_embed._merge_and_index_vector_store(mock_conn, vector_store, vector_store_tmp, MagicMock()) + + mock_drop_idx.assert_called_once() # HNSW drops existing index + mock_execute.assert_called_once() # Merge SQL + mock_drop_vs.assert_called_once() # Drop temp table + mock_create_idx.assert_called_once() # Create index + + +class TestPopulateVs: + """Tests for the populate_vs function.""" + + @patch("server.api.utils.embed.update_vs_comment") + @patch("server.api.utils.embed._merge_and_index_vector_store") + @patch("server.api.utils.embed._embed_documents_in_batches") + @patch("server.api.utils.embed._create_temp_vector_store") + @patch("server.api.utils.embed.utils_databases.connect") + @patch("server.api.utils.embed._prepare_documents") + def test_populate_vs_success( + self, + mock_prepare, + mock_connect, + mock_create_temp, + mock_embed, + mock_merge, + mock_comment, + make_vector_store, + make_database, + ): + """Should populate vector store with documents.""" + mock_prepare.return_value = [LangchainDocument(page_content="Test", metadata={})] + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_create_temp.return_value = (MagicMock(), make_vector_store(vector_store="VS_TMP")) + + docs = [LangchainDocument(page_content="Test", metadata={})] + + utils_embed.populate_vs(make_vector_store(), make_database(), MagicMock(), input_data=docs) + + mock_prepare.assert_called_once() + mock_create_temp.assert_called_once() + mock_embed.assert_called_once() + mock_merge.assert_called_once() + mock_comment.assert_called_once() + + +class TestSplitDocumentExtensions: + """Tests for split_document with various extensions.""" + + def test_split_document_html(self): + """Should split HTML documents using HTMLHeaderTextSplitter.""" + docs = [LangchainDocument(page_content="

<h1>Title</h1><p>Content here</p>
", metadata={"source": "test.html"})] + + result = utils_embed.split_document("default", 500, 50, docs, "html") + + assert len(result) >= 1 + + def test_split_document_md(self): + """Should split Markdown documents.""" + docs = [LangchainDocument(page_content="# Header\n\nContent " * 100, metadata={"source": "test.md"})] + + result = utils_embed.split_document("default", 500, 50, docs, "md") + + assert len(result) >= 1 + + def test_split_document_txt(self): + """Should split text documents.""" + docs = [LangchainDocument(page_content="Text content " * 200, metadata={"source": "test.txt"})] + + result = utils_embed.split_document("default", 500, 50, docs, "txt") + + assert len(result) >= 1 + + def test_split_document_csv(self): + """Should split CSV documents.""" + docs = [LangchainDocument(page_content="col1,col2\nval1,val2\n" * 100, metadata={"source": "test.csv"})] + + result = utils_embed.split_document("default", 500, 50, docs, "csv") + + assert len(result) >= 1 + + +class TestGetDocumentLoaderExtensions: # pylint: disable=protected-access + """Tests for _get_document_loader with various extensions.""" + + def test_get_document_loader_md(self, tmp_path): + """Should return TextLoader for Markdown files.""" + test_file = tmp_path / "test.md" + test_file.touch() + + _, split = utils_embed._get_document_loader(str(test_file), "md") + + assert split is True + + def test_get_document_loader_csv(self, tmp_path): + """Should return CSVLoader for CSV files.""" + test_file = tmp_path / "test.csv" + test_file.write_text("col1,col2\nval1,val2") + + _, split = utils_embed._get_document_loader(str(test_file), "csv") + + assert split is True + + def test_get_document_loader_txt(self, tmp_path): + """Should return TextLoader for text files.""" + test_file = tmp_path / "test.txt" + test_file.touch() + + _, split = utils_embed._get_document_loader(str(test_file), "txt") + + assert split is True + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_embed, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_embed.logger.name == "api.utils.embed" diff --git a/test/unit/server/api/utils/test_utils_mcp.py b/test/unit/server/api/utils/test_utils_mcp.py new file mode 100644 index 00000000..0301a321 --- /dev/null +++ b/test/unit/server/api/utils/test_utils_mcp.py @@ -0,0 +1,192 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/mcp.py +Tests for MCP utility functions. 
+""" + +from unittest.mock import patch, MagicMock, AsyncMock +import os +import pytest + +from server.api.utils import mcp + + +class TestGetClient: + """Tests for the get_client function.""" + + @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) + def test_get_client_default_values(self): + """get_client should return default configuration.""" + result = mcp.get_client() + + assert "mcpServers" in result + assert "optimizer" in result["mcpServers"] + assert result["mcpServers"]["optimizer"]["type"] == "streamableHttp" + assert result["mcpServers"]["optimizer"]["transport"] == "streamable_http" + assert "http://127.0.0.1:8000/mcp/" in result["mcpServers"]["optimizer"]["url"] + + @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) + def test_get_client_custom_server_port(self): + """get_client should use custom server and port.""" + result = mcp.get_client(server="http://custom.server", port=9000) + + assert "http://custom.server:9000/mcp/" in result["mcpServers"]["optimizer"]["url"] + + @patch.dict(os.environ, {"API_SERVER_KEY": "secret-key"}) + def test_get_client_includes_auth_header(self): + """get_client should include authorization header.""" + result = mcp.get_client() + + headers = result["mcpServers"]["optimizer"]["headers"] + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer secret-key" + + @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + def test_get_client_langgraph_removes_type(self): + """get_client should remove type field for langgraph client.""" + result = mcp.get_client(client="langgraph") + + assert "type" not in result["mcpServers"]["optimizer"] + assert "transport" in result["mcpServers"]["optimizer"] + + @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + def test_get_client_non_langgraph_keeps_type(self): + """get_client should keep type field for non-langgraph clients.""" + result = mcp.get_client(client="other") + + assert "type" in result["mcpServers"]["optimizer"] + + @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + def test_get_client_none_client_keeps_type(self): + """get_client should keep type field when client is None.""" + result = mcp.get_client(client=None) + + assert "type" in result["mcpServers"]["optimizer"] + + @patch.dict(os.environ, {"API_SERVER_KEY": ""}) + def test_get_client_empty_api_key(self): + """get_client should handle empty API key.""" + result = mcp.get_client() + + headers = result["mcpServers"]["optimizer"]["headers"] + assert headers["Authorization"] == "Bearer " + + @patch.dict(os.environ, {"API_SERVER_KEY": "key"}) + def test_get_client_structure(self): + """get_client should return expected structure.""" + result = mcp.get_client() + + assert isinstance(result, dict) + assert isinstance(result["mcpServers"], dict) + assert isinstance(result["mcpServers"]["optimizer"], dict) + + optimizer = result["mcpServers"]["optimizer"] + expected_keys = {"type", "transport", "url", "headers"} + assert set(optimizer.keys()) == expected_keys + + +class TestListPrompts: + """Tests for the list_prompts function.""" + + @pytest.mark.asyncio + @patch("server.api.utils.mcp.Client") + async def test_list_prompts_success(self, mock_client_class): + """list_prompts should return list of prompts.""" + mock_prompts = [MagicMock(name="prompt1"), MagicMock(name="prompt2")] + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=mock_prompts) + 
mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + mock_mcp_engine = MagicMock() + + result = await mcp.list_prompts(mock_mcp_engine) + + assert result == mock_prompts + mock_client.list_prompts.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.utils.mcp.Client") + async def test_list_prompts_empty_list(self, mock_client_class): + """list_prompts should return empty list when no prompts.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + mock_mcp_engine = MagicMock() + + result = await mcp.list_prompts(mock_mcp_engine) + + assert result == [] + + @pytest.mark.asyncio + @patch("server.api.utils.mcp.Client") + async def test_list_prompts_closes_client(self, mock_client_class): + """list_prompts should close client after use.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + mock_mcp_engine = MagicMock() + + await mcp.list_prompts(mock_mcp_engine) + + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.utils.mcp.Client") + async def test_list_prompts_creates_client_with_engine(self, mock_client_class): + """list_prompts should create client with MCP engine.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + mock_mcp_engine = MagicMock() + + await mcp.list_prompts(mock_mcp_engine) + + mock_client_class.assert_called_once_with(mock_mcp_engine) + + @pytest.mark.asyncio + @patch("server.api.utils.mcp.Client") + async def test_list_prompts_closes_client_on_exception(self, mock_client_class): + """list_prompts should close client even if exception occurs.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(side_effect=RuntimeError("Test error")) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + mock_mcp_engine = MagicMock() + + with pytest.raises(RuntimeError): + await mcp.list_prompts(mock_mcp_engine) + + mock_client.close.assert_called_once() + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(mcp, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert mcp.logger.name == "api.utils.mcp" diff --git a/test/unit/server/api/utils/test_utils_models.py b/test/unit/server/api/utils/test_utils_models.py new file mode 100644 index 00000000..8616ca9e --- /dev/null +++ b/test/unit/server/api/utils/test_utils_models.py @@ -0,0 +1,433 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/models.py +Tests for model utility functions. 
+""" + +# pylint: disable=too-few-public-methods + +from unittest.mock import patch, MagicMock +import pytest + +from server.api.utils import models as utils_models +from server.api.utils.models import ( + URLUnreachableError, + InvalidModelError, + ExistsModelError, + UnknownModelError, +) + + +class TestExceptions: + """Tests for custom exception classes.""" + + def test_url_unreachable_error_is_value_error(self): + """URLUnreachableError should inherit from ValueError.""" + exc = URLUnreachableError("URL unreachable") + assert isinstance(exc, ValueError) + + def test_invalid_model_error_is_value_error(self): + """InvalidModelError should inherit from ValueError.""" + exc = InvalidModelError("Invalid model") + assert isinstance(exc, ValueError) + + def test_exists_model_error_is_value_error(self): + """ExistsModelError should inherit from ValueError.""" + exc = ExistsModelError("Model exists") + assert isinstance(exc, ValueError) + + def test_unknown_model_error_is_value_error(self): + """UnknownModelError should inherit from ValueError.""" + exc = UnknownModelError("Model not found") + assert isinstance(exc, ValueError) + + +class TestCreate: + """Tests for the create function.""" + + @patch("server.api.utils.models.get") + @patch("server.api.utils.models.MODEL_OBJECTS", []) + def test_create_success(self, mock_get, make_model): + """create should add model to MODEL_OBJECTS.""" + model = make_model(model_id="gpt-4", provider="openai") + mock_get.side_effect = [UnknownModelError("Not found"), (model,)] + + result = utils_models.create(model) + + assert result == model + + @patch("server.api.utils.models.get") + def test_create_raises_exists_error(self, mock_get, make_model): + """create should raise ExistsModelError if model exists.""" + model = make_model(model_id="gpt-4", provider="openai") + mock_get.return_value = [model] + + with pytest.raises(ExistsModelError): + utils_models.create(model) + + @patch("server.api.utils.models.get") + @patch("server.api.utils.models.is_url_accessible") + @patch("server.api.utils.models.MODEL_OBJECTS", []) + def test_create_disables_model_if_url_inaccessible(self, mock_url_check, mock_get, make_model): + """create should disable model if API base URL is inaccessible.""" + model = make_model(model_id="custom", provider="openai") + model.api_base = "https://unreachable.example.com" + mock_get.side_effect = [UnknownModelError("Not found"), (model,)] + mock_url_check.return_value = (False, "Connection refused") + + result = utils_models.create(model, check_url=True) + + assert result.enabled is False + + +class TestGet: + """Tests for the get function.""" + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_all_models(self, mock_objects, make_model): + """get should return all models when no filters.""" + model1 = make_model(model_id="gpt-4", provider="openai") + model2 = make_model(model_id="claude-3", provider="anthropic") + mock_objects.__iter__ = lambda _: iter([model1, model2]) + mock_objects.__len__ = lambda _: 2 + + result = utils_models.get() + + assert len(result) == 2 + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_by_provider(self, mock_objects, make_model): + """get should filter by provider.""" + model1 = make_model(model_id="gpt-4", provider="openai") + model2 = make_model(model_id="claude-3", provider="anthropic") + mock_objects.__iter__ = lambda _: iter([model1, model2]) + mock_objects.__len__ = lambda _: 2 + + result = utils_models.get(model_provider="openai") + + assert len(result) == 1 + assert 
result[0].provider == "openai" + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_by_type(self, mock_objects, make_model): + """get should filter by type.""" + model1 = make_model(model_id="gpt-4", model_type="ll") + model2 = make_model(model_id="embed-3", model_type="embed") + mock_objects.__iter__ = lambda _: iter([model1, model2]) + mock_objects.__len__ = lambda _: 2 + + result = utils_models.get(model_type="embed") + + assert len(result) == 1 + assert result[0].type == "embed" + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_exclude_disabled(self, mock_objects, make_model): + """get should exclude disabled models when include_disabled=False.""" + model1 = make_model(model_id="gpt-4", enabled=True) + model2 = make_model(model_id="gpt-3", enabled=False) + mock_objects.__iter__ = lambda _: iter([model1, model2]) + mock_objects.__len__ = lambda _: 2 + + result = utils_models.get(include_disabled=False) + + assert len(result) == 1 + assert result[0].enabled is True + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_raises_unknown_error(self, mock_objects): + """get should raise UnknownModelError if model_id not found.""" + mock_objects.__iter__ = lambda _: iter([]) + mock_objects.__len__ = lambda _: 0 + + with pytest.raises(UnknownModelError): + utils_models.get(model_id="nonexistent") + + +class TestUpdate: + """Tests for the update function.""" + + @patch("server.api.utils.models.get") + @patch("server.api.utils.models.is_url_accessible") + def test_update_success(self, mock_url_check, mock_get, make_model): + """update should update model in place.""" + existing_model = make_model(model_id="gpt-4", provider="openai") + mock_get.return_value = (existing_model,) + mock_url_check.return_value = (True, "OK") + + payload = make_model(model_id="gpt-4", provider="openai") + payload.temperature = 0.9 + + result = utils_models.update(payload) + + assert result.temperature == 0.9 + + @patch("server.api.utils.models.get") + @patch("server.api.utils.models.is_url_accessible") + def test_update_raises_url_unreachable(self, mock_url_check, mock_get, make_model): + """update should raise URLUnreachableError if URL inaccessible.""" + existing_model = make_model(model_id="gpt-4", provider="openai") + mock_get.return_value = (existing_model,) + mock_url_check.return_value = (False, "Connection refused") + + payload = make_model(model_id="gpt-4", provider="openai", enabled=True) + payload.api_base = "https://unreachable.example.com" + + with pytest.raises(URLUnreachableError): + utils_models.update(payload) + + +class TestDelete: + """Tests for the delete function.""" + + def test_delete_removes_model(self, make_model): + """delete should remove model from MODEL_OBJECTS.""" + model1 = make_model(model_id="gpt-4", provider="openai") + model2 = make_model(model_id="claude-3", provider="anthropic") + + with patch("server.api.utils.models.MODEL_OBJECTS", [model1, model2]) as mock_objects: + utils_models.delete("openai", "gpt-4") + assert len(mock_objects) == 1 + assert mock_objects[0].id == "claude-3" + + +class TestGetSupported: + """Tests for the get_supported function.""" + + @patch("server.api.utils.models.litellm") + def test_get_supported_returns_providers(self, mock_litellm): + """get_supported should return list of providers.""" + mock_provider = MagicMock() + mock_provider.value = "openai" + mock_litellm.provider_list = [mock_provider] + mock_litellm.models_by_provider = {"openai": ["gpt-4"]} + mock_litellm.get_model_info.return_value = {"mode": 
"chat", "key": "gpt-4"} + mock_litellm.get_llm_provider.return_value = ("openai", None, None, "https://api.openai.com/v1") + + result = utils_models.get_supported() + + assert len(result) >= 1 + assert result[0]["provider"] == "openai" + + @patch("server.api.utils.models.litellm") + def test_get_supported_filters_by_provider(self, mock_litellm): + """get_supported should filter by provider.""" + mock_provider1 = MagicMock() + mock_provider1.value = "openai" + mock_provider2 = MagicMock() + mock_provider2.value = "anthropic" + mock_litellm.provider_list = [mock_provider1, mock_provider2] + mock_litellm.models_by_provider = {"openai": [], "anthropic": []} + + result = utils_models.get_supported(model_provider="anthropic") + + assert len(result) == 1 + assert result[0]["provider"] == "anthropic" + + +class TestCreateGenai: + """Tests for the create_genai function.""" + + @patch("server.api.utils.models.utils_oci.get_genai_models") + @patch("server.api.utils.models.get") + @patch("server.api.utils.models.delete") + @patch("server.api.utils.models.create") + def test_create_genai_creates_models(self, mock_create, _mock_delete, mock_get, mock_get_genai, make_oci_config): + """create_genai should create GenAI models.""" + mock_get_genai.return_value = [ + {"model_name": "cohere.command-r", "capabilities": ["CHAT"]}, + {"model_name": "cohere.embed-v3", "capabilities": ["TEXT_EMBEDDINGS"]}, + ] + mock_get.return_value = [] + + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + utils_models.create_genai(config) + + assert mock_create.call_count == 2 + + @patch("server.api.utils.models.utils_oci.get_genai_models") + def test_create_genai_returns_empty_when_no_models(self, mock_get_genai, make_oci_config): + """create_genai should return empty list when no models.""" + mock_get_genai.return_value = [] + + config = make_oci_config(genai_region="us-chicago-1") + + result = utils_models.create_genai(config) + + assert not result + + +class TestGetFullConfig: # pylint: disable=protected-access + """Tests for the _get_full_config function.""" + + @patch("server.api.utils.models.get") + def test_get_full_config_success(self, mock_get, make_model): + """_get_full_config should merge model config with defined model.""" + defined_model = make_model(model_id="gpt-4", provider="openai") + defined_model.api_base = "https://api.openai.com/v1" + mock_get.return_value = (defined_model,) + + model_config = {"model": "openai/gpt-4", "temperature": 0.9} + + full_config, provider = utils_models._get_full_config(model_config, None) + + assert provider == "openai" + assert full_config["temperature"] == 0.9 + assert full_config["api_base"] == "https://api.openai.com/v1" + + @patch("server.api.utils.models.get") + def test_get_full_config_raises_unknown_model(self, mock_get): + """_get_full_config should raise UnknownModelError if not found.""" + mock_get.side_effect = UnknownModelError("Model not found") + + model_config = {"model": "openai/nonexistent"} + + with pytest.raises(UnknownModelError): + utils_models._get_full_config(model_config, None) + + +class TestGetLitellmConfig: + """Tests for the get_litellm_config function.""" + + @patch("server.api.utils.models._get_full_config") + @patch("server.api.utils.models.litellm.get_supported_openai_params") + def test_get_litellm_config_basic(self, mock_get_params, mock_get_full): + """get_litellm_config should return LiteLLM config.""" + mock_get_full.return_value = ( + {"model": "openai/gpt-4", 
"temperature": 0.7, "api_base": "https://api.openai.com/v1"}, + "openai", + ) + mock_get_params.return_value = ["temperature", "max_tokens"] + + model_config = {"model": "openai/gpt-4"} + + result = utils_models.get_litellm_config(model_config, None) + + assert result["model"] == "openai/gpt-4" + assert result["drop_params"] is True + + @patch("server.api.utils.models._get_full_config") + @patch("server.api.utils.models.litellm.get_supported_openai_params") + @patch("server.api.utils.models.utils_oci.get_signer") + def test_get_litellm_config_oci_provider(self, mock_get_signer, mock_get_params, mock_get_full, make_oci_config): + """get_litellm_config should include OCI params for OCI provider.""" + mock_get_full.return_value = ( + { + "model": "oci/cohere.command-r", + "api_base": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + }, + "oci", + ) + mock_get_params.return_value = ["temperature"] + mock_get_signer.return_value = None # API key auth + + oci_config = make_oci_config(genai_region="us-chicago-1") + oci_config.genai_compartment_id = "ocid1.compartment.oc1..test" + oci_config.tenancy = "test-tenancy" + oci_config.user = "test-user" + oci_config.fingerprint = "test-fingerprint" + oci_config.key_file = "/path/to/key" + + model_config = {"model": "oci/cohere.command-r"} + + result = utils_models.get_litellm_config(model_config, oci_config) + + assert result["oci_region"] == "us-chicago-1" + assert result["oci_compartment_id"] == "ocid1.compartment.oc1..test" + + +class TestGetClientEmbed: + """Tests for the get_client_embed function.""" + + @patch("server.api.utils.models._get_full_config") + @patch("server.api.utils.models.utils_oci.init_genai_client") + @patch("server.api.utils.models.OCIGenAIEmbeddings") + def test_get_client_embed_oci(self, mock_embeddings, mock_init_client, mock_get_full, make_oci_config): + """get_client_embed should return OCIGenAIEmbeddings for OCI provider.""" + mock_get_full.return_value = ({"id": "cohere.embed-v3"}, "oci") + mock_init_client.return_value = MagicMock() + mock_embeddings.return_value = MagicMock() + + oci_config = make_oci_config() + oci_config.genai_compartment_id = "ocid1.compartment.oc1..test" + + model_config = {"model": "oci/cohere.embed-v3"} + + utils_models.get_client_embed(model_config, oci_config) + + mock_embeddings.assert_called_once() + + @patch("server.api.utils.models._get_full_config") + @patch("server.api.utils.models.init_embeddings") + def test_get_client_embed_openai(self, mock_init_embeddings, mock_get_full, make_oci_config): + """get_client_embed should use init_embeddings for non-OCI providers.""" + mock_get_full.return_value = ( + {"id": "text-embedding-3-small", "api_base": "https://api.openai.com/v1"}, + "openai", + ) + mock_init_embeddings.return_value = MagicMock() + + oci_config = make_oci_config() + model_config = {"model": "openai/text-embedding-3-small"} + + utils_models.get_client_embed(model_config, oci_config) + + mock_init_embeddings.assert_called_once() + + +class TestProcessModelEntry: # pylint: disable=protected-access + """Tests for the _process_model_entry function.""" + + @patch("server.api.utils.models.litellm") + def test_process_model_entry_success(self, mock_litellm): + """_process_model_entry should return model dict.""" + mock_litellm.get_model_info.return_value = {"mode": "chat", "key": "gpt-4"} + mock_litellm.get_llm_provider.return_value = ("openai", None, None, "https://api.openai.com/v1") + + type_to_modes = {"ll": {"chat"}} + allowed_modes = {"chat"} + + result = 
utils_models._process_model_entry("gpt-4", type_to_modes, allowed_modes, "openai") + + assert result is not None + assert result["type"] == "ll" + + @patch("server.api.utils.models.litellm") + def test_process_model_entry_filters_mode(self, mock_litellm): + """_process_model_entry should return None for unsupported modes.""" + mock_litellm.get_model_info.return_value = {"mode": "moderation"} + + type_to_modes = {"ll": {"chat"}} + allowed_modes = {"chat"} + + result = utils_models._process_model_entry("mod-model", type_to_modes, allowed_modes, "openai") + + assert result is None + + @patch("server.api.utils.models.litellm") + def test_process_model_entry_handles_exception(self, mock_litellm): + """_process_model_entry should handle exceptions gracefully.""" + mock_litellm.get_model_info.side_effect = Exception("API error") + + type_to_modes = {"ll": {"chat"}} + allowed_modes = {"chat"} + + result = utils_models._process_model_entry("bad-model", type_to_modes, allowed_modes, "openai") + + assert result == {"key": "bad-model"} + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_models, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_models.logger.name == "api.utils.models" diff --git a/test/unit/server/api/utils/test_utils_oci.py b/test/unit/server/api/utils/test_utils_oci.py new file mode 100644 index 00000000..3e6d9f2b --- /dev/null +++ b/test/unit/server/api/utils/test_utils_oci.py @@ -0,0 +1,595 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/oci.py +Tests for OCI utility functions. 
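+
+No live OCI calls are made: client construction is patched and
+oci.exceptions.ServiceError is raised from the mocks to verify that service
+failures are translated into OciException with the expected status codes.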
+""" + +# pylint: disable=too-few-public-methods + +from datetime import datetime +from unittest.mock import patch, MagicMock + +import pytest +import oci + +from server.api.utils import oci as utils_oci +from server.api.utils.oci import OciException + + +class TestOciException: + """Tests for OciException class.""" + + def test_oci_exception_init(self): + """OciException should store status_code and detail.""" + exc = OciException(status_code=404, detail="Not found") + assert exc.status_code == 404 + assert exc.detail == "Not found" + + def test_oci_exception_message(self): + """OciException should use detail as message.""" + exc = OciException(status_code=500, detail="Server error") + assert str(exc) == "Server error" + + +class TestGet: + """Tests for the get function.""" + + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS", []) + def test_get_raises_value_error_when_not_configured(self): + """get should raise ValueError when no OCI objects configured.""" + with pytest.raises(ValueError) as exc_info: + utils_oci.get() + assert "not configured" in str(exc_info.value) + + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") + def test_get_returns_all_oci_objects(self, mock_objects, make_oci_config): + """get should return all OCI objects when no filters.""" + oci1 = make_oci_config(auth_profile="PROFILE1") + oci2 = make_oci_config(auth_profile="PROFILE2") + mock_objects.__iter__ = lambda _: iter([oci1, oci2]) + mock_objects.__len__ = lambda _: 2 + mock_objects.__bool__ = lambda _: True + + result = utils_oci.get() + + assert len(result) == 2 + + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") + def test_get_by_auth_profile(self, mock_objects, make_oci_config): + """get should return matching OCI object by auth_profile.""" + oci1 = make_oci_config(auth_profile="PROFILE1") + oci2 = make_oci_config(auth_profile="PROFILE2") + mock_objects.__iter__ = lambda _: iter([oci1, oci2]) + + result = utils_oci.get(auth_profile="PROFILE1") + + assert result.auth_profile == "PROFILE1" + + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") + def test_get_raises_value_error_profile_not_found(self, mock_objects, make_oci_config): + """get should raise ValueError when profile not found.""" + mock_objects.__iter__ = lambda _: iter([make_oci_config(auth_profile="DEFAULT")]) + + with pytest.raises(ValueError) as exc_info: + utils_oci.get(auth_profile="NONEXISTENT") + + assert "not found" in str(exc_info.value) + + def test_get_raises_value_error_both_params(self): + """get should raise ValueError when both client and auth_profile provided.""" + with pytest.raises(ValueError) as exc_info: + utils_oci.get(client="test", auth_profile="DEFAULT") + + assert "not both" in str(exc_info.value) + + @patch("server.api.utils.oci.bootstrap.SETTINGS_OBJECTS") + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") + def test_get_by_client(self, mock_oci, mock_settings, make_oci_config, make_settings): + """get should return OCI object based on client settings.""" + settings = make_settings(client="test_client") + settings.oci.auth_profile = "CLIENT_PROFILE" + mock_settings.__iter__ = lambda _: iter([settings]) + mock_settings.__len__ = lambda _: 1 + + oci_config = make_oci_config(auth_profile="CLIENT_PROFILE") + mock_oci.__iter__ = lambda _: iter([oci_config]) + + result = utils_oci.get(client="test_client") + + assert result.auth_profile == "CLIENT_PROFILE" + + @patch("server.api.utils.oci.bootstrap.SETTINGS_OBJECTS", []) + def test_get_raises_value_error_client_not_found(self): + """get should raise ValueError 
when client not found.""" + with pytest.raises(ValueError) as exc_info: + utils_oci.get(client="nonexistent") + + assert "not found" in str(exc_info.value) + + +class TestGetSigner: + """Tests for the get_signer function.""" + + @patch("server.api.utils.oci.oci.auth.signers.InstancePrincipalsSecurityTokenSigner") + def test_get_signer_instance_principal(self, mock_signer_class, make_oci_config): + """get_signer should return instance principal signer.""" + mock_signer = MagicMock() + mock_signer_class.return_value = mock_signer + config = make_oci_config() + config.authentication = "instance_principal" + + result = utils_oci.get_signer(config) + + assert result == mock_signer + mock_signer_class.assert_called_once() + + @patch("server.api.utils.oci.oci.auth.signers.get_oke_workload_identity_resource_principal_signer") + def test_get_signer_oke_workload_identity(self, mock_signer_func, make_oci_config): + """get_signer should return OKE workload identity signer.""" + mock_signer = MagicMock() + mock_signer_func.return_value = mock_signer + config = make_oci_config() + config.authentication = "oke_workload_identity" + + result = utils_oci.get_signer(config) + + assert result == mock_signer + + def test_get_signer_api_key_returns_none(self, make_oci_config): + """get_signer should return None for API key authentication.""" + config = make_oci_config() + config.authentication = "api_key" + + result = utils_oci.get_signer(config) + + assert result is None + + +class TestInitClient: + """Tests for the init_client function.""" + + @patch("server.api.utils.oci.get_signer") + @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") + def test_init_client_standard_auth(self, mock_client_class, mock_get_signer, make_oci_config): + """init_client should initialize with standard authentication.""" + mock_get_signer.return_value = None + mock_client = MagicMock() + mock_client_class.return_value = mock_client + config = make_oci_config() + + result = utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) + + assert result == mock_client + + @patch("server.api.utils.oci.get_signer") + @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") + def test_init_client_with_signer(self, mock_client_class, mock_get_signer, make_oci_config): + """init_client should use signer when provided.""" + mock_signer = MagicMock() + mock_signer.tenancy_id = "test-tenancy-id" + mock_get_signer.return_value = mock_signer + mock_client = MagicMock() + mock_client_class.return_value = mock_client + config = make_oci_config() + config.authentication = "instance_principal" + config.region = "us-ashburn-1" # Required for signer-based auth + config.tenancy = "existing-tenancy" # Set tenancy so code doesn't try to derive from signer + + result = utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) + + assert result == mock_client + # Check signer was passed to client + call_kwargs = mock_client_class.call_args.kwargs + assert call_kwargs["signer"] == mock_signer + + @patch("server.api.utils.oci.get_signer") + def test_init_client_raises_oci_exception_on_invalid_config(self, mock_get_signer, make_oci_config): + """init_client should raise OciException on invalid config.""" + mock_get_signer.return_value = None + config = make_oci_config() + + with patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") as mock_client: + mock_client.side_effect = oci.exceptions.InvalidConfig("Invalid configuration") + + with pytest.raises(OciException) as exc_info: + 
utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) + + assert exc_info.value.status_code == 400 + + @patch("server.api.utils.oci.get_signer") + @patch("server.api.utils.oci.oci.generative_ai_inference.GenerativeAiInferenceClient") + def test_init_client_genai_sets_service_endpoint(self, mock_client_class, mock_get_signer, make_oci_config): + """init_client should set service endpoint for GenAI client.""" + mock_get_signer.return_value = None + mock_client = MagicMock() + mock_client_class.return_value = mock_client + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + utils_oci.init_client(oci.generative_ai_inference.GenerativeAiInferenceClient, config) + + call_kwargs = mock_client_class.call_args.kwargs + assert "inference.generativeai.us-chicago-1.oci.oraclecloud.com" in call_kwargs["service_endpoint"] + + +class TestGetNamespace: + """Tests for the get_namespace function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_success(self, mock_init_client, make_oci_config): + """get_namespace should return namespace on success.""" + mock_client = MagicMock() + mock_client.get_namespace.return_value.data = "test-namespace" + mock_init_client.return_value = mock_client + config = make_oci_config() + + result = utils_oci.get_namespace(config) + + assert result == "test-namespace" + assert config.namespace == "test-namespace" + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_raises_on_service_error(self, mock_init_client, make_oci_config): + """get_namespace should raise OciException on service error.""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = oci.exceptions.ServiceError( + status=401, code="NotAuthenticated", headers={}, message="Not authenticated" + ) + mock_init_client.return_value = mock_client + config = make_oci_config() + + with pytest.raises(OciException) as exc_info: + utils_oci.get_namespace(config) + + assert exc_info.value.status_code == 401 + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_raises_on_file_not_found(self, mock_init_client, make_oci_config): + """get_namespace should raise OciException on file not found.""" + mock_init_client.side_effect = FileNotFoundError("Key file not found") + config = make_oci_config() + + with pytest.raises(OciException) as exc_info: + utils_oci.get_namespace(config) + + assert exc_info.value.status_code == 400 + + +class TestGetRegions: + """Tests for the get_regions function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_regions_returns_list(self, mock_init_client, make_oci_config): + """get_regions should return list of region subscriptions.""" + mock_region = MagicMock() + mock_region.is_home_region = True + mock_region.region_key = "IAD" + mock_region.region_name = "us-ashburn-1" + mock_region.status = "READY" + + mock_client = MagicMock() + mock_client.list_region_subscriptions.return_value.data = [mock_region] + mock_init_client.return_value = mock_client + config = make_oci_config() + config.tenancy = "test-tenancy" + + result = utils_oci.get_regions(config) + + assert len(result) == 1 + assert result[0]["region_name"] == "us-ashburn-1" + assert result[0]["is_home_region"] is True + + +class TestGetGenaiModels: + """Tests for the get_genai_models function.""" + + def test_get_genai_models_raises_without_compartment(self, make_oci_config): + """get_genai_models should raise OciException without compartment_id.""" + config = 
make_oci_config() + config.genai_compartment_id = None + + with pytest.raises(OciException) as exc_info: + utils_oci.get_genai_models(config) + + assert exc_info.value.status_code == 400 + assert "genai_compartment_id" in exc_info.value.detail + + def test_get_genai_models_regional_raises_without_region(self, make_oci_config): + """get_genai_models should raise OciException without region when regional=True.""" + config = make_oci_config() + config.genai_compartment_id = "ocid1.compartment.oc1..test" + config.genai_region = None + + with pytest.raises(OciException) as exc_info: + utils_oci.get_genai_models(config, regional=True) + + assert exc_info.value.status_code == 400 + assert "genai_region" in exc_info.value.detail + + @patch("server.api.utils.oci.init_client") + def test_get_genai_models_returns_models(self, mock_init_client, make_oci_config): + """get_genai_models should return list of GenAI models.""" + mock_model = MagicMock() + mock_model.display_name = "cohere.command-r-plus" + mock_model.capabilities = ["TEXT_GENERATION"] + mock_model.vendor = "cohere" + mock_model.id = "ocid1.model.oc1..test" + mock_model.time_deprecated = None + mock_model.time_dedicated_retired = None + mock_model.time_on_demand_retired = None + + mock_response = MagicMock() + mock_response.data.items = [mock_model] + + mock_client = MagicMock() + mock_client.list_models.return_value = mock_response + mock_init_client.return_value = mock_client + + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + result = utils_oci.get_genai_models(config, regional=True) + + assert len(result) == 1 + assert result[0]["model_name"] == "cohere.command-r-plus" + + +class TestGetCompartments: + """Tests for the get_compartments function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_compartments_returns_dict(self, mock_init_client, make_oci_config): + """get_compartments should return dict of compartment paths.""" + mock_compartment = MagicMock() + mock_compartment.id = "ocid1.compartment.oc1..test" + mock_compartment.name = "TestCompartment" + mock_compartment.compartment_id = None # Root level + + mock_client = MagicMock() + mock_client.list_compartments.return_value.data = [mock_compartment] + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.tenancy = "test-tenancy" + + result = utils_oci.get_compartments(config) + + assert "TestCompartment" in result + assert result["TestCompartment"] == "ocid1.compartment.oc1..test" + + +class TestGetBuckets: + """Tests for the get_buckets function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_buckets_returns_list(self, mock_init_client, make_oci_config): + """get_buckets should return list of bucket names.""" + mock_bucket = MagicMock() + mock_bucket.name = "test-bucket" + mock_bucket.freeform_tags = {} + + mock_client = MagicMock() + mock_client.list_buckets.return_value.data = [mock_bucket] + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_buckets("compartment-id", config) + + assert result == ["test-bucket"] + + @patch("server.api.utils.oci.init_client") + def test_get_buckets_excludes_genai_chunk_buckets(self, mock_init_client, make_oci_config): + """get_buckets should exclude buckets with genai_chunk=true tag.""" + mock_bucket1 = MagicMock() + mock_bucket1.name = "normal-bucket" + mock_bucket1.freeform_tags = {} + + mock_bucket2 = MagicMock() + 
mock_bucket2.name = "chunk-bucket" + mock_bucket2.freeform_tags = {"genai_chunk": "true"} + + mock_client = MagicMock() + mock_client.list_buckets.return_value.data = [mock_bucket1, mock_bucket2] + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_buckets("compartment-id", config) + + assert result == ["normal-bucket"] + + @patch("server.api.utils.oci.init_client") + def test_get_buckets_raises_on_service_error(self, mock_init_client, make_oci_config): + """get_buckets should raise OciException on service error.""" + mock_client = MagicMock() + mock_client.list_buckets.side_effect = oci.exceptions.ServiceError( + status=401, code="NotAuthenticated", headers={}, message="Not authenticated" + ) + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + with pytest.raises(OciException) as exc_info: + utils_oci.get_buckets("compartment-id", config) + + assert exc_info.value.status_code == 401 + + +class TestGetBucketObjects: + """Tests for the get_bucket_objects function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_bucket_objects_returns_names(self, mock_init_client, make_oci_config): + """get_bucket_objects should return list of object names.""" + mock_obj = MagicMock() + mock_obj.name = "document.pdf" + + mock_response = MagicMock() + mock_response.data.objects = [mock_obj] + + mock_client = MagicMock() + mock_client.list_objects.return_value = mock_response + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_bucket_objects("test-bucket", config) + + assert result == ["document.pdf"] + + @patch("server.api.utils.oci.init_client") + def test_get_bucket_objects_returns_empty_on_not_found(self, mock_init_client, make_oci_config): + """get_bucket_objects should return empty list on service error.""" + mock_client = MagicMock() + mock_client.list_objects.side_effect = oci.exceptions.ServiceError( + status=404, code="BucketNotFound", headers={}, message="Bucket not found" + ) + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_bucket_objects("nonexistent-bucket", config) + + assert result == [] + + +class TestGetBucketObjectsWithMetadata: + """Tests for the get_bucket_objects_with_metadata function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_bucket_objects_with_metadata_returns_supported_files(self, mock_init_client, make_oci_config): + """get_bucket_objects_with_metadata should return only supported file types.""" + mock_pdf = MagicMock() + mock_pdf.name = "document.pdf" + mock_pdf.size = 1000 + mock_pdf.etag = "abc123" + mock_pdf.time_modified = datetime(2024, 1, 1, 12, 0, 0) + mock_pdf.md5 = "md5hash" + + mock_exe = MagicMock() + mock_exe.name = "program.exe" + mock_exe.size = 2000 + mock_exe.etag = "def456" + mock_exe.time_modified = datetime(2024, 1, 1, 12, 0, 0) + mock_exe.md5 = "md5hash2" + + mock_response = MagicMock() + mock_response.data.objects = [mock_pdf, mock_exe] + + mock_client = MagicMock() + mock_client.list_objects.return_value = mock_response + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_bucket_objects_with_metadata("test-bucket", config) + + assert len(result) == 1 + assert result[0]["name"] == "document.pdf" + 
assert result[0]["extension"] == "pdf" + + +class TestDetectChangedObjects: + """Tests for the detect_changed_objects function.""" + + def test_detect_new_objects(self): + """detect_changed_objects should identify new objects.""" + current_objects = [{"name": "new_file.pdf", "etag": "abc123", "time_modified": "2024-01-01T12:00:00"}] + processed_objects = {} + + new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) + + assert len(new) == 1 + assert len(modified) == 0 + assert new[0]["name"] == "new_file.pdf" + + def test_detect_modified_objects(self): + """detect_changed_objects should identify modified objects.""" + current_objects = [{"name": "existing.pdf", "etag": "new_etag", "time_modified": "2024-01-02T12:00:00"}] + processed_objects = {"existing.pdf": {"etag": "old_etag", "time_modified": "2024-01-01T12:00:00"}} + + new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) + + assert len(new) == 0 + assert len(modified) == 1 + assert modified[0]["name"] == "existing.pdf" + + def test_detect_unchanged_objects(self): + """detect_changed_objects should not flag unchanged objects.""" + current_objects = [{"name": "existing.pdf", "etag": "same_etag", "time_modified": "2024-01-01T12:00:00"}] + processed_objects = {"existing.pdf": {"etag": "same_etag", "time_modified": "2024-01-01T12:00:00"}} + + new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) + + assert len(new) == 0 + assert len(modified) == 0 + + def test_detect_skips_old_format_metadata(self): + """detect_changed_objects should skip objects with old format metadata.""" + current_objects = [{"name": "old_format.pdf", "etag": "new_etag", "time_modified": "2024-01-02T12:00:00"}] + processed_objects = {"old_format.pdf": {"etag": None, "time_modified": None}} + + new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) + + assert len(new) == 0 + assert len(modified) == 0 + + +class TestGetObject: + """Tests for the get_object function.""" + + @patch("server.api.utils.oci.init_client") + def test_get_object_downloads_file(self, mock_init_client, make_oci_config, tmp_path): + """get_object should download file to directory.""" + mock_response = MagicMock() + mock_response.data.raw.stream.return_value = [b"file content"] + + mock_client = MagicMock() + mock_client.get_object.return_value = mock_response + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_object(str(tmp_path), "folder/document.pdf", "test-bucket", config) + + assert result == str(tmp_path / "document.pdf") + assert (tmp_path / "document.pdf").exists() + assert (tmp_path / "document.pdf").read_bytes() == b"file content" + + +class TestInitGenaiClient: + """Tests for the init_genai_client function.""" + + @patch("server.api.utils.oci.init_client") + def test_init_genai_client_calls_init_client(self, mock_init_client, make_oci_config): + """init_genai_client should call init_client with correct type.""" + mock_client = MagicMock() + mock_init_client.return_value = mock_client + config = make_oci_config() + + result = utils_oci.init_genai_client(config) + + mock_init_client.assert_called_once_with(oci.generative_ai_inference.GenerativeAiInferenceClient, config) + assert result == mock_client + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_oci, "logger") 
+ + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_oci.logger.name == "api.utils.oci" diff --git a/test/unit/server/api/utils/test_utils_settings.py b/test/unit/server/api/utils/test_utils_settings.py new file mode 100644 index 00000000..1f84ec41 --- /dev/null +++ b/test/unit/server/api/utils/test_utils_settings.py @@ -0,0 +1,352 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/settings.py +Tests for settings utility functions. +""" + +# pylint: disable=too-few-public-methods + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from server.api.utils import settings as utils_settings +from server.api.utils.settings import bootstrap + + +class TestCreateClient: + """Tests for the create_client function.""" + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_create_client_success(self, mock_settings, make_settings): + """create_client should create new client from default settings.""" + default_settings = make_settings(client="default") + # Return new iterator each time __iter__ is called (consumed twice: any() and next()) + mock_settings.__iter__ = lambda _: iter([default_settings]) + mock_settings.__bool__ = lambda _: True + mock_settings.append = MagicMock() + + result = utils_settings.create_client("new_client") + + assert result.client == "new_client" + mock_settings.append.assert_called_once() + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_create_client_raises_on_existing(self, mock_settings, make_settings): + """create_client should raise ValueError if client exists.""" + existing_settings = make_settings(client="existing") + mock_settings.__iter__ = lambda _: iter([existing_settings]) + + with pytest.raises(ValueError) as exc_info: + utils_settings.create_client("existing") + + assert "already exists" in str(exc_info.value) + + +class TestGetClient: + """Tests for the get_client function.""" + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_success(self, mock_settings, make_settings): + """get_client should return client settings.""" + client_settings = make_settings(client="test_client") + mock_settings.__iter__ = lambda _: iter([client_settings]) + + result = utils_settings.get_client("test_client") + + assert result.client == "test_client" + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_raises_on_not_found(self, mock_settings): + """get_client should raise ValueError if client not found.""" + mock_settings.__iter__ = lambda _: iter([]) + + with pytest.raises(ValueError) as exc_info: + utils_settings.get_client("nonexistent") + + assert "not found" in str(exc_info.value) + + +class TestUpdateClient: + """Tests for the update_client function.""" + + @patch("server.api.utils.settings.get_client") + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_update_client_success(self, mock_settings, mock_get_client, make_settings): + """update_client should update and return client settings.""" + old_settings = make_settings(client="test_client") + new_settings = make_settings(client="other") + + mock_get_client.side_effect = [old_settings, new_settings] + mock_settings.remove = MagicMock() + mock_settings.append = MagicMock() + + utils_settings.update_client(new_settings, "test_client") + + 
mock_settings.remove.assert_called_once_with(old_settings) + mock_settings.append.assert_called_once() + + +class TestGetMcpPromptsWithOverrides: + """Tests for the get_mcp_prompts_with_overrides function.""" + + @pytest.mark.asyncio + @patch("server.api.utils.settings.utils_mcp.list_prompts") + @patch("server.api.utils.settings.defaults") + @patch("server.api.utils.settings.cache.get_override") + async def test_get_mcp_prompts_with_overrides_success(self, mock_get_override, mock_defaults, mock_list_prompts): + """get_mcp_prompts_with_overrides should return list of MCPPrompt.""" + mock_prompt = MagicMock() + mock_prompt.name = "optimizer_test-prompt" + mock_prompt.title = "Test Prompt" + mock_prompt.description = "Test description" + mock_prompt.meta = {"_fastmcp": {"tags": ["rag", "chat"]}} + + mock_list_prompts.return_value = [mock_prompt] + + mock_default_func = MagicMock() + mock_default_func.return_value.content.text = "Default text" + mock_defaults.optimizer_test_prompt = mock_default_func + + mock_get_override.return_value = None + + mock_mcp_engine = MagicMock() + + result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) + + assert len(result) == 1 + assert result[0].name == "optimizer_test-prompt" + assert result[0].text == "Default text" + + @pytest.mark.asyncio + @patch("server.api.utils.settings.utils_mcp.list_prompts") + @patch("server.api.utils.settings.defaults") + @patch("server.api.utils.settings.cache.get_override") + async def test_get_mcp_prompts_uses_override(self, mock_get_override, mock_defaults, mock_list_prompts): + """get_mcp_prompts_with_overrides should use override text when available.""" + mock_prompt = MagicMock() + mock_prompt.name = "optimizer_test-prompt" + mock_prompt.title = None + mock_prompt.description = None + mock_prompt.meta = None + + mock_list_prompts.return_value = [mock_prompt] + + mock_default_func = MagicMock() + mock_default_func.return_value.content.text = "Default text" + mock_defaults.optimizer_test_prompt = mock_default_func + + mock_get_override.return_value = "Override text" + + mock_mcp_engine = MagicMock() + + result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) + + assert result[0].text == "Override text" + + @pytest.mark.asyncio + @patch("server.api.utils.settings.utils_mcp.list_prompts") + async def test_get_mcp_prompts_filters_non_optimizer(self, mock_list_prompts): + """get_mcp_prompts_with_overrides should filter out non-optimizer prompts.""" + mock_prompt1 = MagicMock() + mock_prompt1.name = "optimizer_test" + mock_prompt1.title = None + mock_prompt1.description = None + mock_prompt1.meta = None + + mock_prompt2 = MagicMock() + mock_prompt2.name = "other_prompt" + + mock_list_prompts.return_value = [mock_prompt1, mock_prompt2] + + mock_mcp_engine = MagicMock() + + with patch("server.api.utils.settings.defaults") as mock_defaults: + mock_defaults.optimizer_test = None + with patch("server.api.utils.settings.cache.get_override", return_value=None): + result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) + + assert len(result) == 1 + assert result[0].name == "optimizer_test" + + +class TestGetServer: + """Tests for the get_server function.""" + + @pytest.mark.asyncio + @patch("server.api.utils.settings.get_mcp_prompts_with_overrides") + @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) + async def 
test_get_server_returns_config(self, mock_get_prompts): + """get_server should return server configuration dict.""" + mock_get_prompts.return_value = [] + mock_mcp_engine = MagicMock() + + result = await utils_settings.get_server(mock_mcp_engine) + + assert "database_configs" in result + assert "model_configs" in result + assert "oci_configs" in result + assert "prompt_configs" in result + + +class TestUpdateServer: + """Tests for the update_server function.""" + + @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) + def test_update_server_updates_databases(self, make_database, make_settings): + """update_server should update database objects.""" + config_data = { + "client_settings": make_settings().model_dump(), + "database_configs": [make_database(name="NEW_DB").model_dump()], + } + + utils_settings.update_server(config_data) + + assert len(bootstrap.DATABASE_OBJECTS) == 1 + + @patch("server.api.utils.settings._load_prompt_configs") + @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) + @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) + def test_update_server_loads_prompt_configs(self, mock_load_prompts, make_settings): + """update_server should load prompt configs.""" + config_data = { + "client_settings": make_settings().model_dump(), + "prompt_configs": [{"name": "test", "title": "Test Title", "text": "Test text"}], + } + + utils_settings.update_server(config_data) + + mock_load_prompts.assert_called_once_with(config_data) + + +class TestLoadPromptOverride: # pylint: disable=protected-access + """Tests for the _load_prompt_override function.""" + + @patch("server.api.utils.settings.cache.set_override") + def test_load_prompt_override_with_text(self, mock_set_override): + """_load_prompt_override should set cache with text.""" + prompt = {"name": "test_prompt", "text": "Test text"} + + result = utils_settings._load_prompt_override(prompt) + + assert result is True + mock_set_override.assert_called_once_with("test_prompt", "Test text") + + @patch("server.api.utils.settings.cache.set_override") + def test_load_prompt_override_without_text(self, mock_set_override): + """_load_prompt_override should return False without text.""" + prompt = {"name": "test_prompt"} + + result = utils_settings._load_prompt_override(prompt) + + assert result is False + mock_set_override.assert_not_called() + + +class TestLoadPromptConfigs: # pylint: disable=protected-access + """Tests for the _load_prompt_configs function.""" + + @patch("server.api.utils.settings._load_prompt_override") + def test_load_prompt_configs_with_prompts(self, mock_load_override): + """_load_prompt_configs should load all prompts.""" + mock_load_override.return_value = True + config_data = {"prompt_configs": [{"name": "p1", "text": "t1"}, {"name": "p2", "text": "t2"}]} + + utils_settings._load_prompt_configs(config_data) + + assert mock_load_override.call_count == 2 + + @patch("server.api.utils.settings._load_prompt_override") + def test_load_prompt_configs_without_key(self, mock_load_override): + """_load_prompt_configs should handle missing prompt_configs key.""" + config_data = {} + + utils_settings._load_prompt_configs(config_data) + + mock_load_override.assert_not_called() + + @patch("server.api.utils.settings._load_prompt_override") + def test_load_prompt_configs_empty_list(self, 
mock_load_override): + """_load_prompt_configs should handle empty prompt_configs.""" + config_data = {"prompt_configs": []} + + utils_settings._load_prompt_configs(config_data) + + mock_load_override.assert_not_called() + + +class TestLoadConfigFromJsonData: + """Tests for the load_config_from_json_data function.""" + + @patch("server.api.utils.settings.update_server") + @patch("server.api.utils.settings.update_client") + def test_load_config_from_json_data_with_client(self, mock_update_client, mock_update_server, make_settings): + """load_config_from_json_data should update specific client.""" + config_data = {"client_settings": make_settings().model_dump()} + + utils_settings.load_config_from_json_data(config_data, client="test_client") + + mock_update_server.assert_called_once() + mock_update_client.assert_called_once() + + @patch("server.api.utils.settings.update_server") + @patch("server.api.utils.settings.update_client") + def test_load_config_from_json_data_without_client(self, mock_update_client, mock_update_server, make_settings): + """load_config_from_json_data should update server and default when no client.""" + config_data = {"client_settings": make_settings().model_dump()} + + utils_settings.load_config_from_json_data(config_data, client=None) + + mock_update_server.assert_called_once() + assert mock_update_client.call_count == 2 # "server" and "default" + + @patch("server.api.utils.settings.update_server") + def test_load_config_from_json_data_raises_missing_settings(self, _mock_update_server): + """load_config_from_json_data should raise KeyError if missing client_settings.""" + config_data = {} + + with pytest.raises(KeyError) as exc_info: + utils_settings.load_config_from_json_data(config_data) + + assert "client_settings" in str(exc_info.value) + + +class TestReadConfigFromJsonFile: + """Tests for the read_config_from_json_file function.""" + + @patch.dict("os.environ", {"CONFIG_FILE": "/path/to/config.json"}) + @patch("os.path.isfile", return_value=True) + @patch("os.access", return_value=True) + @patch("builtins.open") + def test_read_config_from_json_file_success(self, mock_open, mock_access, mock_isfile, make_settings): + """read_config_from_json_file should return Configuration.""" + _ = (mock_access, mock_isfile) # Used to suppress unused argument warning + + config_data = {"client_settings": make_settings().model_dump()} + mock_open.return_value.__enter__.return_value.read.return_value = json.dumps(config_data) + + # Mock json.load + with patch("json.load", return_value=config_data): + result = utils_settings.read_config_from_json_file() + + assert result is not None + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(utils_settings, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert utils_settings.logger.name == "api.core.settings" diff --git a/test/unit/server/api/utils/test_utils_testbed.py b/test/unit/server/api/utils/test_utils_testbed.py new file mode 100644 index 00000000..f68ab797 --- /dev/null +++ b/test/unit/server/api/utils/test_utils_testbed.py @@ -0,0 +1,324 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/utils/testbed.py +Tests for testbed utility functions. 
+ +Uses hybrid approach: +- Real Oracle database for testbed table creation and querying +- Mocks for external dependencies (PDF processing, LLM calls) +""" + +# pylint: disable=too-few-public-methods + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from server.api.utils import testbed as utils_testbed + + +class TestJsonlToJsonContent: + """Tests for the jsonl_to_json_content function.""" + + def test_jsonl_to_json_content_single_json(self): + """Should parse single JSON object.""" + content = '{"question": "What is AI?", "answer": "Artificial Intelligence"}' + + result = utils_testbed.jsonl_to_json_content(content) + + parsed = json.loads(result) + assert parsed["question"] == "What is AI?" + + def test_jsonl_to_json_content_jsonl(self): + """Should parse JSONL (multiple lines).""" + content = '{"q": "Q1"}\n{"q": "Q2"}' + + result = utils_testbed.jsonl_to_json_content(content) + + parsed = json.loads(result) + assert len(parsed) == 2 + + def test_jsonl_to_json_content_bytes(self): + """Should handle bytes input.""" + content = b'{"question": "test"}' + + result = utils_testbed.jsonl_to_json_content(content) + + parsed = json.loads(result) + assert parsed["question"] == "test" + + def test_jsonl_to_json_content_single_jsonl(self): + """Should handle single line JSONL.""" + content = '{"question": "test"}\n' + + result = utils_testbed.jsonl_to_json_content(content) + + parsed = json.loads(result) + assert parsed["question"] == "test" + + def test_jsonl_to_json_content_invalid(self): + """Should raise ValueError for invalid content.""" + content = "not valid json at all" + + with pytest.raises(ValueError) as exc_info: + utils_testbed.jsonl_to_json_content(content) + + assert "Invalid JSONL content" in str(exc_info.value) + + +class TestCreateTestsetObjects: + """Tests for the create_testset_objects function. + + Uses mocks since DDL (CREATE TABLE) causes implicit commits in Oracle, + which breaks savepoint-based test isolation. + """ + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_create_testset_objects_executes_ddl(self, mock_execute): + """Should execute SQL to create testset tables.""" + mock_conn = MagicMock() + + utils_testbed.create_testset_objects(mock_conn) + + # Should execute 3 DDL statements (testsets, testset_qa, evaluations) + assert mock_execute.call_count == 3 + + +class TestGetTestsets: + """Tests for the get_testsets function. + + Uses mocks since the function may trigger DDL which causes implicit commits. 
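+    Here execute_sql is patched directly; side_effect sequences simulate both
+    an existing schema and the lazy table-creation retry on first use.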
+ """ + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_get_testsets_returns_list(self, mock_execute): + """Should return list of TestSets.""" + mock_conn = MagicMock() + # Return empty result set + mock_execute.return_value = [] + + result = utils_testbed.get_testsets(mock_conn) + + assert isinstance(result, list) + assert len(result) == 0 + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_get_testsets_creates_tables_on_first_call(self, mock_execute): + """Should create tables if they don't exist.""" + mock_conn = MagicMock() + # First call returns None (which causes TypeError during unpacking), + # then 3 DDL calls for table creation, then final query returns [] + mock_execute.side_effect = [None, None, None, None, []] + + result = utils_testbed.get_testsets(mock_conn) + + assert isinstance(result, list) + + +class TestGetTestsetQa: + """Tests for the get_testset_qa function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_get_testset_qa_returns_qa(self, mock_execute): + """Should return TestSetQA object.""" + mock_execute.return_value = [('{"question": "Q1"}',)] + mock_conn = MagicMock() + + result = utils_testbed.get_testset_qa(mock_conn, "abc123") + + assert len(result.qa_data) == 1 + + +class TestGetEvaluations: + """Tests for the get_evaluations function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_get_evaluations_returns_list(self, mock_execute): + """Should return list of Evaluation objects.""" + mock_eid = MagicMock() + mock_eid.hex.return_value = "eval123" + mock_execute.return_value = [(mock_eid, "2024-01-01", 0.85)] + mock_conn = MagicMock() + + result = utils_testbed.get_evaluations(mock_conn, "tid123") + + assert len(result) == 1 + assert result[0].correctness == 0.85 + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + @patch("server.api.utils.testbed.create_testset_objects") + def test_get_evaluations_creates_tables_on_error(self, mock_create, mock_execute): + """Should create tables if TypeError occurs.""" + mock_execute.return_value = None + mock_conn = MagicMock() + + result = utils_testbed.get_evaluations(mock_conn, "tid123") + + mock_create.assert_called_once() + assert result == [] + + +class TestDeleteQa: + """Tests for the delete_qa function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_delete_qa_executes_sql(self, mock_execute): + """Should execute DELETE SQL.""" + mock_conn = MagicMock() + + utils_testbed.delete_qa(mock_conn, "tid123") + + mock_execute.assert_called_once() + mock_conn.commit.assert_called_once() + + +class TestUpsertQa: + """Tests for the upsert_qa function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_upsert_qa_single_qa(self, mock_execute): + """Should handle single QA object.""" + mock_execute.return_value = "tid123" + mock_conn = MagicMock() + json_data = '{"question": "Q1", "answer": "A1"}' + + result = utils_testbed.upsert_qa(mock_conn, "TestSet", "2024-01-01T00:00:00.000", json_data) + + mock_execute.assert_called_once() + assert result == "tid123" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_upsert_qa_multiple_qa(self, mock_execute): + """Should handle multiple QA objects.""" + mock_execute.return_value = "tid123" + mock_conn = MagicMock() + json_data = '[{"q": "Q1"}, {"q": "Q2"}]' + + utils_testbed.upsert_qa(mock_conn, "TestSet", "2024-01-01T00:00:00.000", json_data) + + 
mock_execute.assert_called_once() + + +class TestInsertEvaluation: + """Tests for the insert_evaluation function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + def test_insert_evaluation_executes_sql(self, mock_execute): + """Should execute INSERT SQL.""" + mock_execute.return_value = "eid123" + mock_conn = MagicMock() + + result = utils_testbed.insert_evaluation( + mock_conn, "tid123", "2024-01-01T00:00:00.000", 0.85, '{"model": "gpt-4"}', b"report_data" + ) + + mock_execute.assert_called_once() + assert result == "eid123" + + +class TestLoadAndSplit: + """Tests for the load_and_split function.""" + + @patch("server.api.utils.testbed.PdfReader") + @patch("server.api.utils.testbed.SentenceSplitter") + def test_load_and_split_processes_pdf(self, mock_splitter, mock_reader): + """Should load PDF and split into nodes.""" + mock_page = MagicMock() + mock_page.extract_text.return_value = "Page content" + mock_reader.return_value.pages = [mock_page] + + mock_splitter_instance = MagicMock() + mock_splitter_instance.return_value = ["node1", "node2"] + mock_splitter.return_value = mock_splitter_instance + + utils_testbed.load_and_split("/path/to/doc.pdf", chunk_size=1024) + + mock_reader.assert_called_once_with("/path/to/doc.pdf") + mock_splitter.assert_called_once_with(chunk_size=1024) + + +class TestBuildKnowledgeBase: + """Tests for the build_knowledge_base function.""" + + @patch("server.api.utils.testbed.utils_models.get_litellm_config") + @patch("server.api.utils.testbed.set_llm_model") + @patch("server.api.utils.testbed.set_embedding_model") + @patch("server.api.utils.testbed.KnowledgeBase") + @patch("server.api.utils.testbed.generate_testset") + def test_build_knowledge_base_success( + self, mock_generate, mock_kb, mock_set_embed, mock_set_llm, mock_get_config, make_oci_config + ): + """Should create knowledge base and generate testset.""" + mock_get_config.return_value = {"api_key": "test"} + mock_testset = MagicMock() + mock_generate.return_value = mock_testset + + mock_text_node = MagicMock() + mock_text_node.text = "Sample text" + text_nodes = [mock_text_node] + + oci_config = make_oci_config() + + result = utils_testbed.build_knowledge_base( + text_nodes, + questions=5, + ll_model="openai/gpt-4", + embed_model="openai/text-embedding-3-small", + oci_config=oci_config, + ) + + mock_set_llm.assert_called_once() + mock_set_embed.assert_called_once() + mock_kb.assert_called_once() + mock_generate.assert_called_once() + assert result == mock_testset + + +class TestProcessReport: + """Tests for the process_report function.""" + + @patch("server.api.utils.testbed.utils_databases.execute_sql") + @patch("server.api.utils.testbed.pickle.loads") + def test_process_report_success(self, mock_pickle, mock_execute, make_settings): + """Should process evaluation report.""" + mock_eid = MagicMock() + mock_eid.hex.return_value = "eid123" + + mock_report = MagicMock() + mock_report.to_pandas.return_value = MagicMock(to_dict=MagicMock(return_value={})) + mock_report.correctness_by_topic.return_value = MagicMock(to_dict=MagicMock(return_value={})) + mock_report.failures = MagicMock(to_dict=MagicMock(return_value={})) + mock_pickle.return_value = mock_report + + # Settings needs to be a valid Settings object (or dict with required fields) + settings_data = make_settings().model_dump() + mock_execute.return_value = [ + { + "EID": mock_eid, + "EVALUATED": "2024-01-01", + "CORRECTNESS": 0.85, + "SETTINGS": settings_data, + "RAG_REPORT": b"data", + } + ] + mock_conn = MagicMock() + + 
result = utils_testbed.process_report(mock_conn, "eid123")
+
+        assert result.eid == "eid123"
+        assert result.correctness == 0.85
+
+
+class TestLoggerConfiguration:
+    """Tests for logger configuration."""
+
+    def test_logger_exists(self):
+        """Logger should be configured."""
+        assert hasattr(utils_testbed, "logger")
+
+    def test_logger_name(self):
+        """Logger should have correct name."""
+        assert utils_testbed.logger.name == "api.utils.testbed"
diff --git a/test/unit/server/api/utils/test_utils_webscrape.py b/test/unit/server/api/utils/test_utils_webscrape.py
new file mode 100644
index 00000000..cd042992
--- /dev/null
+++ b/test/unit/server/api/utils/test_utils_webscrape.py
@@ -0,0 +1,419 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Unit tests for server/api/utils/webscrape.py
+Tests for web scraping and content extraction utilities.
+"""
+
+# pylint: disable=too-few-public-methods
+
+from test.unit.server.api.conftest import create_mock_aiohttp_session
+from unittest.mock import patch, AsyncMock
+
+import pytest
+from bs4 import BeautifulSoup
+
+from server.api.utils import webscrape
+
+
+class TestNormalizeWs:
+    """Tests for the normalize_ws function."""
+
+    def test_normalize_ws_removes_extra_spaces(self):
+        """normalize_ws should collapse multiple spaces into one."""
+        result = webscrape.normalize_ws("Hello    world")
+        assert result == "Hello world"
+
+    def test_normalize_ws_removes_newlines(self):
+        """normalize_ws should replace newlines with spaces."""
+        result = webscrape.normalize_ws("Hello\n\nworld")
+        assert result == "Hello world"
+
+    def test_normalize_ws_strips_whitespace(self):
+        """normalize_ws should strip leading/trailing whitespace."""
+        result = webscrape.normalize_ws("  Hello world  ")
+        assert result == "Hello world"
+
+    def test_normalize_ws_handles_tabs(self):
+        """normalize_ws should handle tab characters."""
+        result = webscrape.normalize_ws("Hello\t\tworld")
+        assert result == "Hello world"
+
+    def test_normalize_ws_normalizes_unicode(self):
+        """normalize_ws should normalize unicode characters."""
+        # NFKC normalization should convert full-width to half-width
+        result = webscrape.normalize_ws("Ｈｅｌｌｏ")  # Full-width characters
+        assert result == "Hello"
+
+    def test_normalize_ws_empty_string(self):
+        """normalize_ws should handle empty string."""
+        result = webscrape.normalize_ws("")
+        assert result == ""
+
+
+class TestCleanSoup:
+    """Tests for the clean_soup function."""
+
+    def test_clean_soup_removes_script_tags(self):
+        """clean_soup should remove script tags."""
+        html = "<html><body><script>var x = 1;</script><p>Content</p></body></html>"
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find("script") is None
+        assert soup.find("p") is not None
+
+    def test_clean_soup_removes_style_tags(self):
+        """clean_soup should remove style tags."""
+        html = "<html><body><style>p { color: red; }</style><p>Content</p></body></html>"
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find("style") is None
+
+    def test_clean_soup_removes_noscript_tags(self):
+        """clean_soup should remove noscript tags."""
+        html = "<html><body><noscript></noscript><p>Content</p></body></html>"
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find("noscript") is None
+
+    def test_clean_soup_removes_nav_elements(self):
+        """clean_soup should remove navigation elements."""
+        html = '<html><body><nav class="menu"></nav><p>Content</p></body></html>'
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find("nav") is None
+
+    def test_clean_soup_removes_elements_by_class(self):
+        """clean_soup should remove elements with bad class names."""
+        html = '<html><body><div class="footer"></div><p>Content</p></body></html>'
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find(class_="footer") is None
+
+    def test_clean_soup_preserves_content(self):
+        """clean_soup should preserve main content."""
+        html = "<html><body><p>Important content</p></body></html>"
+        soup = BeautifulSoup(html, "html.parser")
+
+        webscrape.clean_soup(soup)
+
+        assert soup.find("p") is not None
+        assert "Important content" in soup.get_text()
+
+
+class TestHeadingLevel:
+    """Tests for the heading_level function."""
+
+    def test_heading_level_h1(self):
+        """heading_level should return 1 for h1."""
+        soup = BeautifulSoup("<h1>Title</h1>", "html.parser")
+        tag = soup.find("h1")
+
+        result = webscrape.heading_level(tag)
+
+        assert result == 1
+
+    def test_heading_level_h2(self):
+        """heading_level should return 2 for h2."""
+        soup = BeautifulSoup("<h2>Title</h2>", "html.parser")
+        tag = soup.find("h2")
+
+        result = webscrape.heading_level(tag)
+
+        assert result == 2
+
+    def test_heading_level_h6(self):
+        """heading_level should return 6 for h6."""
+        soup = BeautifulSoup("<h6>Title</h6>", "html.parser")
+        tag = soup.find("h6")
+
+        result = webscrape.heading_level(tag)
+
+        assert result == 6
+
+
+class TestGroupBySections:
+    """Tests for the group_by_sections function."""
+
+    def test_group_by_sections_extracts_sections(self):
+        """group_by_sections should extract section content."""
+        html = """
+        <html>
+            <section>
+                <h2>Section Title</h2>
+                <p>Paragraph 1</p>
+                <p>Paragraph 2</p>
+            </section>
+        </html>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_sections(soup)
+
+        assert len(result) == 1
+        assert result[0]["title"] == "Section Title"
+        assert "Paragraph 1" in result[0]["content"]
+
+    def test_group_by_sections_handles_articles(self):
+        """group_by_sections should handle article tags."""
+        html = """
+        <html>
+            <article>
+                <h2>Article Title</h2>
+                <p>Article content</p>
+            </article>
+        </html>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_sections(soup)
+
+        assert len(result) == 1
+        assert result[0]["title"] == "Article Title"
+
+    def test_group_by_sections_no_sections(self):
+        """group_by_sections should return empty list when no sections."""
+        html = "<html><body><p>Plain content</p></body></html>"
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_sections(soup)
+
+        assert not result
+
+
+class TestTableToMarkdown:
+    """Tests for the table_to_markdown function."""
+
+    def test_table_to_markdown_basic_table(self):
+        """table_to_markdown should convert table to markdown."""
+        html = """
+        <table>
+            <tr><th>Header 1</th><th>Header 2</th></tr>
+            <tr><td>Cell 1</td><td>Cell 2</td></tr>
+        </table>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+        table = soup.find("table")
+
+        result = webscrape.table_to_markdown(table)
+
+        assert "| Header 1 | Header 2 |" in result
+        assert "| --- | --- |" in result
+        assert "| Cell 1 | Cell 2 |" in result
+
+    def test_table_to_markdown_empty_table(self):
+        """table_to_markdown should handle empty table."""
+        html = "<table></table>"
+        soup = BeautifulSoup(html, "html.parser")
+        table = soup.find("table")
+
+        result = webscrape.table_to_markdown(table)
+
+        assert result == ""
+
+
+class TestGroupByHeadings:
+    """Tests for the group_by_headings function."""
+
+    def test_group_by_headings_extracts_sections(self):
+        """group_by_headings should group content by heading."""
+        html = """

+        <html>
+        <h2>Section 1</h2>
+        <p>Content 1</p>
+        <h2>Section 2</h2>
+        <p>Content 2</p>
+        </html>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_headings(soup)
+
+        assert len(result) == 2
+        assert result[0]["title"] == "Section 1"
+        assert result[1]["title"] == "Section 2"
+
+    def test_group_by_headings_handles_lists(self):
+        """group_by_headings should include list items."""
+        html = """

+        <html>
+        <h2>List Section</h2>
+        <ul>
+            <li>Item 1</li>
+            <li>Item 2</li>
+        </ul>
+        </html>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_headings(soup)
+
+        assert len(result) == 1
+        assert "Item 1" in result[0]["content"]
+
+    def test_group_by_headings_respects_hierarchy(self):
+        """group_by_headings should stop at same or higher level heading."""
+        html = """

+        <html>
+        <h2>Parent</h2>
+        <p>Parent content</p>
+        <h3>Child</h3>
+        <p>Child content</p>
+        <h2>Sibling</h2>
+        <p>Sibling content</p>
+        </html>
+        """
+        soup = BeautifulSoup(html, "html.parser")
+
+        result = webscrape.group_by_headings(soup)
+
+        # h2 sections should not include content from sibling h2
+        parent_section = next(s for s in result if s["title"] == "Parent")
+        assert "Sibling content" not in parent_section["content"]
+
+
+class TestSectionsToMarkdown:
+    """Tests for the sections_to_markdown function."""
+
+    def test_sections_to_markdown_basic(self):
+        """sections_to_markdown should convert sections to markdown."""
+        sections = [
+            {"title": "Section 1", "level": 1, "paragraphs": ["Para 1"]},
+            {"title": "Section 2", "level": 2, "paragraphs": ["Para 2"]},
+        ]
+
+        result = webscrape.sections_to_markdown(sections)
+
+        assert "# Section 1" in result
+        assert "## Section 2" in result
+
+    def test_sections_to_markdown_empty_list(self):
+        """sections_to_markdown should handle empty list."""
+        result = webscrape.sections_to_markdown([])
+
+        assert result == ""
+
+
+class TestSlugify:
+    """Tests for the slugify function."""
+
+    def test_slugify_basic(self):
+        """slugify should convert text to URL-safe slug."""
+        result = webscrape.slugify("Hello World")
+
+        assert result == "hello-world"
+
+    def test_slugify_special_characters(self):
+        """slugify should remove special characters."""
+        result = webscrape.slugify("Hello! World?")
+
+        assert result == "hello-world"
+
+    def test_slugify_max_length(self):
+        """slugify should respect max length."""
+        long_text = "a" * 100
+        result = webscrape.slugify(long_text, max_len=10)
+
+        assert len(result) == 10
+
+    def test_slugify_empty_string(self):
+        """slugify should return 'page' for empty result."""
+        result = webscrape.slugify("!!!")
+
+        assert result == "page"
+
+    def test_slugify_multiple_spaces(self):
+        """slugify should collapse multiple spaces/dashes."""
+        result = webscrape.slugify("Hello   World")
+
+        assert result == "hello-world"
+
+
+class TestFetchAndExtractParagraphs:
+    """Tests for the fetch_and_extract_paragraphs function."""
+
+    @pytest.mark.asyncio
+    @patch("aiohttp.ClientSession")
+    async def test_fetch_and_extract_paragraphs_success(self, mock_session_class):
+        """fetch_and_extract_paragraphs should extract paragraphs from URL."""
+        html = "<html><body><p>Paragraph 1</p><p>Paragraph 2</p></body></html>"
+
+        mock_response = AsyncMock()
+        mock_response.text = AsyncMock(return_value=html)
+        create_mock_aiohttp_session(mock_session_class, mock_response)
+
+        result = await webscrape.fetch_and_extract_paragraphs("https://example.com")
+
+        assert len(result) == 2
+        assert "Paragraph 1" in result
+        assert "Paragraph 2" in result
+
+
+class TestFetchAndExtractSections:
+    """Tests for the fetch_and_extract_sections function."""
+
+    @pytest.mark.asyncio
+    @patch("aiohttp.ClientSession")
+    async def test_fetch_and_extract_sections_with_sections(self, mock_session_class):
+        """fetch_and_extract_sections should extract sections from URL."""
+        html = """

+        <html>
+        <section><h2>Title</h2><p>Content</p></section>
+        </html>
+        """
+
+        mock_response = AsyncMock()
+        mock_response.text = AsyncMock(return_value=html)
+        create_mock_aiohttp_session(mock_session_class, mock_response)
+
+        result = await webscrape.fetch_and_extract_sections("https://example.com")
+
+        assert len(result) == 1
+        assert result[0]["title"] == "Title"
+
+    @pytest.mark.asyncio
+    @patch("aiohttp.ClientSession")
+    async def test_fetch_and_extract_sections_falls_back_to_headings(self, mock_session_class):
+        """fetch_and_extract_sections should fall back to headings."""
+        html = """

+        <html>
+        <h2>Heading</h2>
+        <p>Content</p>
+        </html>
+ + """ + + mock_response = AsyncMock() + mock_response.text = AsyncMock(return_value=html) + create_mock_aiohttp_session(mock_session_class, mock_response) + + result = await webscrape.fetch_and_extract_sections("https://example.com") + + assert len(result) == 1 + assert result[0]["title"] == "Heading" + + +class TestBadChunks: + """Tests for the BAD_CHUNKS constant.""" + + def test_bad_chunks_contains_common_elements(self): + """BAD_CHUNKS should contain common unwanted elements.""" + assert "nav" in webscrape.BAD_CHUNKS + assert "header" in webscrape.BAD_CHUNKS + assert "footer" in webscrape.BAD_CHUNKS + assert "ads" in webscrape.BAD_CHUNKS + assert "comment" in webscrape.BAD_CHUNKS + + def test_bad_chunks_is_list(self): + """BAD_CHUNKS should be a list.""" + assert isinstance(webscrape.BAD_CHUNKS, list) diff --git a/test/unit/server/api/v1/__init__.py b/test/unit/server/api/v1/__init__.py new file mode 100644 index 00000000..a6ad55f3 --- /dev/null +++ b/test/unit/server/api/v1/__init__.py @@ -0,0 +1 @@ +# v1 API unit test package diff --git a/test/unit/server/api/v1/test_v1_chat.py b/test/unit/server/api/v1/test_v1_chat.py new file mode 100644 index 00000000..9c0b9b75 --- /dev/null +++ b/test/unit/server/api/v1/test_v1_chat.py @@ -0,0 +1,258 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/chat.py +Tests for chat completion endpoints. +""" + +from unittest.mock import patch, MagicMock +import pytest +from fastapi.responses import StreamingResponse + +from server.api.v1 import chat + + +class TestChatPost: + """Tests for the chat_post endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chat.completion_generator") + async def test_chat_post_returns_last_message(self, mock_generator, make_chat_request): + """chat_post should return the final completion message.""" + request = make_chat_request(content="Hello") + mock_response = {"choices": [{"message": {"content": "Hi there!"}}]} + + async def mock_gen(): + yield mock_response + + mock_generator.return_value = mock_gen() + + result = await chat.chat_post(request=request, client="test_client") + + assert result == mock_response + mock_generator.assert_called_once_with("test_client", request, "completions") + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chat.completion_generator") + async def test_chat_post_iterates_through_all_chunks(self, mock_generator, make_chat_request): + """chat_post should iterate through all chunks and return last.""" + request = make_chat_request(content="Hello") + + async def mock_gen(): + yield "chunk1" + yield "chunk2" + yield {"final": "response"} + + mock_generator.return_value = mock_gen() + + result = await chat.chat_post(request=request, client="test_client") + + assert result == {"final": "response"} + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chat.completion_generator") + async def test_chat_post_uses_default_client(self, mock_generator, make_chat_request): + """chat_post should use 'server' as default client.""" + request = make_chat_request() + + async def mock_gen(): + yield {"response": "data"} + + mock_generator.return_value = mock_gen() + + await chat.chat_post(request=request, client="server") + + mock_generator.assert_called_once_with("server", request, "completions") + + +class TestChatStream: + """Tests for the chat_stream endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chat.completion_generator") 
+ async def test_chat_stream_returns_streaming_response(self, mock_generator, make_chat_request): + """chat_stream should return a StreamingResponse.""" + request = make_chat_request(content="Hello") + + async def mock_gen(): + yield b"chunk1" + yield b"chunk2" + + mock_generator.return_value = mock_gen() + + result = await chat.chat_stream(request=request, client="test_client") + + assert isinstance(result, StreamingResponse) + assert result.media_type == "application/octet-stream" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chat.completion_generator") + async def test_chat_stream_calls_generator_with_streams_mode(self, mock_generator, make_chat_request): + """chat_stream should call generator with 'streams' mode.""" + request = make_chat_request() + + async def mock_gen(): + yield b"data" + + mock_generator.return_value = mock_gen() + + await chat.chat_stream(request=request, client="test_client") + + mock_generator.assert_called_once_with("test_client", request, "streams") + + +class TestChatHistoryClean: + """Tests for the chat_history_clean endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_clean_success(self, mock_graph): + """chat_history_clean should clear history and return confirmation.""" + mock_graph.update_state = MagicMock(return_value=None) + + result = await chat.chat_history_clean(client="test_client") + + assert len(result) == 1 + assert "forgotten" in result[0].content + assert result[0].role == "system" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_clean_updates_state_correctly(self, mock_graph): + """chat_history_clean should update state with correct values.""" + mock_graph.update_state = MagicMock(return_value=None) + + await chat.chat_history_clean(client="test_client") + + call_args = mock_graph.update_state.call_args + values = call_args[1]["values"] + + assert "messages" in values + assert values["cleaned_messages"] == [] + assert values["context_input"] == "" + assert values["documents"] == {} + assert values["final_response"] == {} + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_clean_handles_key_error(self, mock_graph): + """chat_history_clean should handle KeyError gracefully.""" + mock_graph.update_state = MagicMock(side_effect=KeyError("thread not found")) + + result = await chat.chat_history_clean(client="nonexistent_client") + + assert len(result) == 1 + assert "no history" in result[0].content + assert result[0].role == "system" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_clean_uses_correct_thread_id(self, mock_graph): + """chat_history_clean should use client as thread_id.""" + mock_graph.update_state = MagicMock(return_value=None) + + await chat.chat_history_clean(client="my_client_id") + + call_args = mock_graph.update_state.call_args + # config is passed as keyword argument, RunnableConfig is dict-like + config = call_args.kwargs["config"] + + assert config["configurable"]["thread_id"] == "my_client_id" + + +class TestChatHistoryReturn: + """Tests for the chat_history_return endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + @patch("server.api.v1.chat.convert_to_openai_messages") + async def test_chat_history_return_success(self, mock_convert, mock_graph): + """chat_history_return should return chat messages.""" + mock_messages = [ + 
MagicMock(content="Hello", role="user"), + MagicMock(content="Hi there", role="assistant"), + ] + mock_state = MagicMock() + mock_state.values = {"messages": mock_messages} + mock_graph.get_state = MagicMock(return_value=mock_state) + mock_convert.return_value = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + + result = await chat.chat_history_return(client="test_client") + + assert len(result) == 2 + mock_convert.assert_called_once_with(mock_messages) + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_return_handles_key_error(self, mock_graph): + """chat_history_return should handle KeyError gracefully.""" + mock_graph.get_state = MagicMock(side_effect=KeyError("thread not found")) + + result = await chat.chat_history_return(client="nonexistent_client") + + assert len(result) == 1 + assert "no history" in result[0].content + assert result[0].role == "system" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + async def test_chat_history_return_uses_correct_thread_id(self, mock_graph): + """chat_history_return should use client as thread_id.""" + mock_state = MagicMock() + mock_state.values = {"messages": []} + mock_graph.get_state = MagicMock(return_value=mock_state) + + with patch("server.api.v1.chat.convert_to_openai_messages", return_value=[]): + await chat.chat_history_return(client="my_client_id") + + call_args = mock_graph.get_state.call_args + # config is passed as keyword argument, RunnableConfig is dict-like + config = call_args.kwargs["config"] + + assert config["configurable"]["thread_id"] == "my_client_id" + + @pytest.mark.asyncio + @patch("server.api.v1.chat.chatbot.chatbot_graph") + @patch("server.api.v1.chat.convert_to_openai_messages") + async def test_chat_history_return_empty_history(self, mock_convert, mock_graph): + """chat_history_return should handle empty history.""" + mock_state = MagicMock() + mock_state.values = {"messages": []} + mock_graph.get_state = MagicMock(return_value=mock_state) + mock_convert.return_value = [] + + result = await chat.chat_history_return(client="test_client") + + assert result == [] + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(chat, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in chat.auth.routes] + + assert "/completions" in routes + assert "/streams" in routes + assert "/history" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(chat, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert chat.logger.name == "endpoints.v1.chat" diff --git a/test/unit/server/api/v1/test_v1_databases.py b/test/unit/server/api/v1/test_v1_databases.py new file mode 100644 index 00000000..7d052cd9 --- /dev/null +++ b/test/unit/server/api/v1/test_v1_databases.py @@ -0,0 +1,184 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/databases.py +Tests for database configuration endpoints. + +Note: These tests mock utils_databases functions to test endpoint logic +(HTTP responses, error handling). 
The underlying database operations +are tested with real Oracle database in test_utils_databases.py. +""" + +from unittest.mock import patch, MagicMock +import pytest +from fastapi import HTTPException + +from server.api.v1 import databases + + +class TestDatabasesList: + """Tests for the databases_list endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_list_returns_all_databases(self, mock_get_databases, make_database): + """databases_list should return all configured databases.""" + db_list = [ + make_database(name="DB1"), + make_database(name="DB2"), + ] + mock_get_databases.return_value = db_list + + result = await databases.databases_list() + + assert result == db_list + mock_get_databases.assert_called_once_with(validate=False) + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_list_returns_empty_list(self, mock_get_databases): + """databases_list should return empty list when no databases.""" + mock_get_databases.return_value = [] + + result = await databases.databases_list() + + assert result == [] + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_list_raises_404_on_value_error(self, mock_get_databases): + """databases_list should raise 404 when ValueError occurs.""" + mock_get_databases.side_effect = ValueError("No databases found") + + with pytest.raises(HTTPException) as exc_info: + await databases.databases_list() + + assert exc_info.value.status_code == 404 + + +class TestDatabasesGet: + """Tests for the databases_get endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_get_returns_single_database(self, mock_get_databases, make_database): + """databases_get should return a single database by name.""" + database = make_database(name="TEST_DB") + mock_get_databases.return_value = database + + result = await databases.databases_get(name="TEST_DB") + + assert result == database + mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=True) + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_get_raises_404_when_not_found(self, mock_get_databases): + """databases_get should raise 404 when database not found.""" + mock_get_databases.side_effect = ValueError("Database not found") + + with pytest.raises(HTTPException) as exc_info: + await databases.databases_get(name="NONEXISTENT") + + assert exc_info.value.status_code == 404 + + +class TestDatabasesUpdate: + """Tests for the databases_update endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + @patch("server.api.v1.databases.utils_databases.connect") + @patch("server.api.v1.databases.utils_databases.disconnect") + async def test_databases_update_returns_updated_database( + self, _mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth + ): + """databases_update should return the updated database.""" + existing_db = make_database(name="TEST_DB", user="old_user") + # First call returns the single db, second call returns list for cleanup + mock_get_databases.side_effect = [existing_db, [existing_db]] + mock_connect.return_value = MagicMock() + + payload = make_database_auth(user="new_user", password="new_pass", dsn="localhost:1521/TEST") + + result = await 
databases.databases_update(name="TEST_DB", payload=payload) + + assert result.user == "new_user" + assert result.connected is True + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + async def test_databases_update_raises_404_when_not_found(self, mock_get_databases, make_database_auth): + """databases_update should raise 404 when database not found.""" + mock_get_databases.side_effect = ValueError("Database not found") + + payload = make_database_auth() + + with pytest.raises(HTTPException) as exc_info: + await databases.databases_update(name="NONEXISTENT", payload=payload) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + @patch("server.api.v1.databases.utils_databases.connect") + async def test_databases_update_raises_400_on_value_error( + self, mock_connect, mock_get_databases, make_database, make_database_auth + ): + """databases_update should raise 400 on ValueError during connect.""" + existing_db = make_database(name="TEST_DB") + mock_get_databases.return_value = existing_db + mock_connect.side_effect = ValueError("Invalid parameters") + + payload = make_database_auth() + + with pytest.raises(HTTPException) as exc_info: + await databases.databases_update(name="TEST_DB", payload=payload) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + @patch("server.api.v1.databases.utils_databases.connect") + async def test_databases_update_raises_401_on_permission_error( + self, mock_connect, mock_get_databases, make_database, make_database_auth + ): + """databases_update should raise 401 on PermissionError during connect.""" + existing_db = make_database(name="TEST_DB") + mock_get_databases.return_value = existing_db + mock_connect.side_effect = PermissionError("Access denied") + + payload = make_database_auth() + + with pytest.raises(HTTPException) as exc_info: + await databases.databases_update(name="TEST_DB", payload=payload) + + assert exc_info.value.status_code == 401 + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(databases, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in databases.auth.routes] + + assert "" in routes + assert "/{name}" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(databases, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert databases.logger.name == "endpoints.v1.databases" diff --git a/test/unit/server/api/v1/test_v1_embed.py b/test/unit/server/api/v1/test_v1_embed.py new file mode 100644 index 00000000..17a442c9 --- /dev/null +++ b/test/unit/server/api/v1/test_v1_embed.py @@ -0,0 +1,553 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/embed.py +Tests for document embedding and vector store endpoints. 
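+
+Most endpoints below are exercised by patching their collaborators at the
+module path and awaiting the coroutine directly inside an async test; a
+minimal sketch of the pattern (names taken from the tests that follow):
+
+    with patch("server.api.v1.embed.utils_databases.get_client_database"):
+        result = await embed.embed_get_files(vs="VS_TEST", client="test_client")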
+""" +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# Pytest fixtures use parameter injection where fixture names match parameters + +from io import BytesIO +from pathlib import Path +from test.unit.server.api.conftest import create_mock_aiohttp_session +from unittest.mock import patch, MagicMock, AsyncMock +import json + +import pytest +from fastapi import HTTPException, UploadFile +from pydantic import HttpUrl + +from common.schema import DatabaseVectorStorage, VectorStoreRefreshRequest +from server.api.v1 import embed +from server.api.utils.databases import DbException + + +@pytest.fixture +def split_embed_mocks(): + """Fixture providing bundled mocks for split_embed tests.""" + with patch("server.api.v1.embed.utils_oci.get") as mock_oci_get, \ + patch("server.api.v1.embed.utils_embed.get_temp_directory") as mock_get_temp, \ + patch("server.api.v1.embed.utils_embed.load_and_split_documents") as mock_load_split, \ + patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed, \ + patch("server.api.v1.embed.functions.get_vs_table") as mock_get_vs_table, \ + patch("server.api.v1.embed.utils_embed.populate_vs") as mock_populate, \ + patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db, \ + patch("shutil.rmtree") as mock_rmtree: + yield { + "oci_get": mock_oci_get, + "get_temp": mock_get_temp, + "load_split": mock_load_split, + "get_embed": mock_get_embed, + "get_vs_table": mock_get_vs_table, + "populate": mock_populate, + "get_db": mock_get_db, + "rmtree": mock_rmtree, + } + + +class TestExtractProviderErrorMessage: + """Tests for the _extract_provider_error_message helper function.""" + + def test_exception_with_message(self): + """Test extraction of exception with message""" + error = Exception("Something went wrong") + result = embed._extract_provider_error_message(error) + assert result == "Something went wrong" + + def test_exception_without_message(self): + """Test extraction of exception without message""" + error = ValueError() + result = embed._extract_provider_error_message(error) + assert result == "Error: ValueError" + + def test_openai_quota_exceeded(self): + """Test extraction of OpenAI quota exceeded error message""" + error_msg = ( + "Error code: 429 - {'error': {'message': 'You exceeded your current quota, " + "please check your plan and billing details.', 'type': 'insufficient_quota'}}" + ) + error = Exception(error_msg) + result = embed._extract_provider_error_message(error) + assert result == error_msg + + def test_openai_rate_limit(self): + """Test extraction of OpenAI rate limit error message""" + error_msg = "Rate limit exceeded. Please try again later." 
+ error = Exception(error_msg) + result = embed._extract_provider_error_message(error) + assert result == error_msg + + def test_complex_error_message(self): + """Test extraction of complex multi-line error message""" + error_msg = "Connection failed\nTimeout: 30s\nHost: api.example.com" + error = Exception(error_msg) + result = embed._extract_provider_error_message(error) + assert result == error_msg + + @pytest.mark.parametrize( + "error_message", + [ + "OpenAI API key is invalid", + "Cohere API error occurred", + "OCI service error", + "Database connection failed", + "Rate limit exceeded for model xyz", + ], + ) + def test_various_error_messages(self, error_message): + """Test that various error messages are passed through correctly""" + error = Exception(error_message) + result = embed._extract_provider_error_message(error) + assert result == error_message + + +class TestEmbedDropVs: + """Tests for the embed_drop_vs endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_databases.connect") + @patch("server.api.v1.embed.utils_databases.drop_vs") + async def test_embed_drop_vs_success(self, mock_drop, mock_connect, mock_get_db, make_database): + """embed_drop_vs should drop vector store and return success.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_connect.return_value = MagicMock() + mock_drop.return_value = None + + result = await embed.embed_drop_vs(vs="VS_TEST", client="test_client") + + assert result.status_code == 200 + mock_drop.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_databases.connect") + @patch("server.api.v1.embed.utils_databases.drop_vs") + async def test_embed_drop_vs_raises_400_on_db_exception(self, mock_drop, mock_connect, mock_get_db, make_database): + """embed_drop_vs should raise 400 on DbException.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_connect.return_value = MagicMock() + mock_drop.side_effect = DbException(status_code=400, detail="Table not found") + + with pytest.raises(HTTPException) as exc_info: + await embed.embed_drop_vs(vs="VS_NONEXISTENT", client="test_client") + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_databases.connect") + @patch("server.api.v1.embed.utils_databases.drop_vs") + async def test_embed_drop_vs_response_contains_vs_name(self, mock_drop, mock_connect, mock_get_db, make_database): + """embed_drop_vs response should contain vector store name.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_connect.return_value = MagicMock() + mock_drop.return_value = None + + result = await embed.embed_drop_vs(vs="VS_MY_STORE", client="test_client") + + body = json.loads(result.body) + assert "VS_MY_STORE" in body["message"] + + +class TestEmbedGetFiles: + """Tests for the embed_get_files endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_files") + async def test_embed_get_files_success(self, mock_get_files, mock_get_db, make_database): + """embed_get_files should return file list.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_get_files.return_value = [ + {"filename": "file1.pdf", "chunks": 10}, + {"filename": 
"file2.txt", "chunks": 5}, + ] + + result = await embed.embed_get_files(vs="VS_TEST", client="test_client") + + assert result.status_code == 200 + mock_get_files.assert_called_once_with(mock_db, "VS_TEST") + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_files") + async def test_embed_get_files_raises_400_on_exception(self, mock_get_files, mock_get_db, make_database): + """embed_get_files should raise 400 on exception.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_get_files.side_effect = Exception("Query failed") + + with pytest.raises(HTTPException) as exc_info: + await embed.embed_get_files(vs="VS_TEST", client="test_client") + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_files") + async def test_embed_get_files_empty_list(self, mock_get_files, mock_get_db, make_database): + """embed_get_files should return empty list for empty vector store.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_get_files.return_value = [] + + result = await embed.embed_get_files(vs="VS_EMPTY", client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert body == [] + + +class TestCommentVs: + """Tests for the comment_vs endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.update_vs_comment") + async def test_comment_vs_success(self, mock_update_comment, mock_get_db, make_database, make_vector_store): + """comment_vs should update vector store comment and return success.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_update_comment.return_value = None + + request = make_vector_store(vector_store="VS_TEST") + + result = await embed.comment_vs(request=request, client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert "comment updated" in body["message"] + mock_update_comment.assert_called_once_with(vector_store=request, db_details=mock_db) + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.update_vs_comment") + async def test_comment_vs_calls_get_client_database( + self, mock_update_comment, mock_get_db, make_database, make_vector_store + ): + """comment_vs should call get_client_database with correct client.""" + mock_db = make_database() + mock_get_db.return_value = mock_db + mock_update_comment.return_value = None + + request = make_vector_store() + + await embed.comment_vs(request=request, client="my_client") + + mock_get_db.assert_called_once_with("my_client") + + +class TestStoreSqlFile: + """Tests for the store_sql_file endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.functions.run_sql_query") + async def test_store_sql_file_success(self, mock_run_sql, mock_get_temp, tmp_path): + """store_sql_file should execute SQL and return file path.""" + mock_get_temp.return_value = tmp_path + mock_run_sql.return_value = "result.csv" + + result = await embed.store_sql_file(request=["conn_str", "SELECT * FROM table"], client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert "result.csv" in body + + 
@pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.functions.run_sql_query") + async def test_store_sql_file_calls_run_sql_query(self, mock_run_sql, mock_get_temp, tmp_path): + """store_sql_file should call run_sql_query with correct params.""" + mock_get_temp.return_value = tmp_path + mock_run_sql.return_value = "output.csv" + + await embed.store_sql_file(request=["db_conn", "SELECT 1"], client="test_client") + + mock_run_sql.assert_called_once_with(db_conn="db_conn", query="SELECT 1", base_path=tmp_path) + + +class TestStoreWebFile: + """Tests for the store_web_file endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.web_parse.fetch_and_extract_sections") + @patch("server.api.v1.embed.web_parse.slugify") + @patch("aiohttp.ClientSession") + async def test_store_web_file_html_success( + self, mock_session_class, mock_slugify, mock_fetch_sections, mock_get_temp, tmp_path + ): + """store_web_file should fetch HTML and extract sections.""" + mock_get_temp.return_value = tmp_path + mock_slugify.return_value = "test-page" + mock_fetch_sections.return_value = [{"title": "Section 1", "content": "Content 1"}] + + mock_response = AsyncMock() + mock_response.headers = {"Content-Type": "text/html"} + mock_response.read = AsyncMock(return_value=b"") + create_mock_aiohttp_session(mock_session_class, mock_response) + + result = await embed.store_web_file(request=[HttpUrl("https://example.com/page")], client="test_client") + + assert result.status_code == 200 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("aiohttp.ClientSession") + async def test_store_web_file_pdf_success(self, mock_session_class, mock_get_temp, tmp_path): + """store_web_file should download PDF files.""" + mock_get_temp.return_value = tmp_path + + mock_response = AsyncMock() + mock_response.headers = {"Content-Type": "application/pdf"} + mock_response.read = AsyncMock(return_value=b"%PDF-1.4") + create_mock_aiohttp_session(mock_session_class, mock_response) + + result = await embed.store_web_file(request=[HttpUrl("https://example.com/doc.pdf")], client="test_client") + + assert result.status_code == 200 + + +class TestStoreLocalFile: + """Tests for the store_local_file endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_store_local_file_success(self, mock_get_temp, tmp_path): + """store_local_file should save uploaded files.""" + mock_get_temp.return_value = tmp_path + + mock_file = UploadFile(file=BytesIO(b"Test content"), filename="test.txt") + + result = await embed.store_local_file(files=[mock_file], client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert "test.txt" in body + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_store_local_file_creates_metadata(self, mock_get_temp, tmp_path): + """store_local_file should create metadata file.""" + mock_get_temp.return_value = tmp_path + + mock_file = UploadFile(file=BytesIO(b"Test content"), filename="test.txt") + + await embed.store_local_file(files=[mock_file], client="test_client") + + metadata_file = tmp_path / ".file_metadata.json" + assert metadata_file.exists() + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_store_local_file_multiple_files(self, mock_get_temp, tmp_path): + 
"""store_local_file should handle multiple files.""" + mock_get_temp.return_value = tmp_path + + files = [ + UploadFile(file=BytesIO(b"Content 1"), filename="file1.txt"), + UploadFile(file=BytesIO(b"Content 2"), filename="file2.txt"), + ] + + result = await embed.store_local_file(files=files, client="test_client") + + body = json.loads(result.body) + assert len(body) == 2 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_store_local_file_metadata_excludes_metadata_file(self, mock_get_temp, tmp_path): + """store_local_file should not include metadata file in response.""" + mock_get_temp.return_value = tmp_path + + mock_file = UploadFile(file=BytesIO(b"Content"), filename="test.txt") + + result = await embed.store_local_file(files=[mock_file], client="test_client") + + body = json.loads(result.body) + assert ".file_metadata.json" not in body + + +class TestSplitEmbed: + """Tests for the split_embed endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_split_embed_raises_404_when_no_files(self, mock_get_temp, mock_oci_get, tmp_path, make_oci_config): + """split_embed should raise 404 when no files found.""" + mock_oci_get.return_value = make_oci_config() + mock_get_temp.return_value = tmp_path # Empty directory + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + with pytest.raises(HTTPException) as exc_info: + await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + async def test_split_embed_raises_404_when_folder_not_found(self, mock_get_temp, mock_oci_get, make_oci_config): + """split_embed should raise 404 when folder not found.""" + mock_oci_get.return_value = make_oci_config() + mock_get_temp.return_value = Path("/nonexistent/path") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + with pytest.raises(HTTPException) as exc_info: + await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_split_embed_success( + self, split_embed_mocks, tmp_path, make_oci_config, make_database + ): + """split_embed should process files and populate vector store.""" + mocks = split_embed_mocks + mocks["oci_get"].return_value = make_oci_config() + mocks["get_temp"].return_value = tmp_path + mocks["load_split"].return_value = (["doc1", "doc2"], None) + mocks["get_embed"].return_value = MagicMock() + mocks["get_vs_table"].return_value = ("VS_TEST", "test_alias") + mocks["populate"].return_value = None + mocks["get_db"].return_value = make_database() + + # Create a test file + (tmp_path / "test.txt").write_text("Test content") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + result = await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert result.status_code == 200 + mocks["populate"].assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.utils_embed.load_and_split_documents") + @patch("shutil.rmtree") + async def 
test_split_embed_raises_500_on_value_error( + self, _mock_rmtree, mock_load_split, mock_get_temp, mock_oci_get, tmp_path, make_oci_config + ): + """split_embed should raise 500 on ValueError during processing.""" + mock_oci_get.return_value = make_oci_config() + mock_get_temp.return_value = tmp_path + mock_load_split.side_effect = ValueError("Invalid document format") + + # Create a test file + (tmp_path / "test.txt").write_text("Test content") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + with pytest.raises(HTTPException) as exc_info: + await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert exc_info.value.status_code == 500 + + +class TestRefreshVectorStore: + """Tests for the refresh_vector_store endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") + @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") + async def test_refresh_vector_store_no_files( + self, + mock_get_objects, + mock_get_vs, + mock_get_db, + mock_oci_get, + make_oci_config, + make_database, + make_vector_store, + ): + """refresh_vector_store should return success when no files.""" + mock_oci_get.return_value = make_oci_config() + mock_get_db.return_value = make_database() + mock_get_vs.return_value = make_vector_store() + mock_get_objects.return_value = [] + + request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + result = await embed.refresh_vector_store(request=request, client="test_client") + + assert result.status_code == 200 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + async def test_refresh_vector_store_raises_400_on_value_error(self, mock_oci_get): + """refresh_vector_store should raise 400 on ValueError.""" + mock_oci_get.side_effect = ValueError("Invalid config") + + request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + with pytest.raises(HTTPException) as exc_info: + await embed.refresh_vector_store(request=request, client="test_client") + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") + async def test_refresh_vector_store_raises_500_on_db_exception( + self, mock_get_vs, mock_get_db, mock_oci_get, make_oci_config, make_database + ): + """refresh_vector_store should raise 500 on DbException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_db.return_value = make_database() + mock_get_vs.side_effect = DbException(status_code=500, detail="Database error") + + request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + with pytest.raises(HTTPException) as exc_info: + await embed.refresh_vector_store(request=request, client="test_client") + + assert exc_info.value.status_code == 500 + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(embed, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in embed.auth.routes] + + assert "/{vs}" in routes + assert "/{vs}/files" in 
routes + assert "/comment" in routes + assert "/sql/store" in routes + assert "/web/store" in routes + assert "/local/store" in routes + assert "/" in routes + assert "/refresh" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(embed, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert embed.logger.name == "api.v1.embed" diff --git a/test/unit/server/api/v1/test_v1_mcp.py b/test/unit/server/api/v1/test_v1_mcp.py new file mode 100644 index 00000000..dc10c82b --- /dev/null +++ b/test/unit/server/api/v1/test_v1_mcp.py @@ -0,0 +1,169 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/mcp.py +Tests for MCP (Model Context Protocol) endpoints. +""" + +# pylint: disable=too-few-public-methods + +from unittest.mock import patch, MagicMock, AsyncMock +import pytest + +from server.api.v1 import mcp + + +class TestGetMcp: + """Tests for the get_mcp dependency function.""" + + def test_get_mcp_returns_fastmcp_app(self): + """get_mcp should return the FastMCP app from request state.""" + mock_request = MagicMock() + mock_fastmcp = MagicMock() + mock_request.app.state.fastmcp_app = mock_fastmcp + + result = mcp.get_mcp(mock_request) + + assert result == mock_fastmcp + + +class TestGetClient: + """Tests for the get_client endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.utils_mcp.get_client") + async def test_get_client_returns_config(self, mock_get_client): + """get_client should return MCP client configuration.""" + expected_config = { + "mcpServers": { + "optimizer": { + "type": "streamableHttp", + "transport": "streamable_http", + "url": "http://127.0.0.1:8000/mcp/", + "headers": {"Authorization": "Bearer test-key"}, + } + } + } + mock_get_client.return_value = expected_config + + result = await mcp.get_client(server="http://127.0.0.1", port=8000) + + assert result == expected_config + mock_get_client.assert_called_once_with("http://127.0.0.1", 8000) + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.utils_mcp.get_client") + async def test_get_client_with_default_params(self, mock_get_client): + """get_client should use default parameters.""" + mock_get_client.return_value = {} + + await mcp.get_client() + + mock_get_client.assert_called_once_with(None, None) + + +class TestGetTools: + """Tests for the get_tools endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.Client") + async def test_get_tools_returns_tool_list(self, mock_client_class, mock_fastmcp): + """get_tools should return list of MCP tools.""" + mock_tool1 = MagicMock() + mock_tool1.model_dump.return_value = {"name": "optimizer_tool1"} + mock_tool2 = MagicMock() + mock_tool2.model_dump.return_value = {"name": "optimizer_tool2"} + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_tools = AsyncMock(return_value=[mock_tool1, mock_tool2]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + result = await mcp.get_tools(mcp_engine=mock_fastmcp) + + assert len(result) == 2 + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.Client") + async def test_get_tools_returns_empty_list(self, mock_client_class, mock_fastmcp): + 
"""get_tools should return empty list when no tools.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_tools = AsyncMock(return_value=[]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + result = await mcp.get_tools(mcp_engine=mock_fastmcp) + + assert result == [] + + +class TestMcpListResources: + """Tests for the mcp_list_resources endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.Client") + async def test_mcp_list_resources_returns_resource_list(self, mock_client_class, mock_fastmcp): + """mcp_list_resources should return list of resources.""" + mock_resource = MagicMock() + mock_resource.model_dump.return_value = {"name": "test_resource", "uri": "resource://test"} + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_resources = AsyncMock(return_value=[mock_resource]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + result = await mcp.mcp_list_resources(mcp_engine=mock_fastmcp) + + assert len(result) == 1 + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.mcp.Client") + async def test_mcp_list_resources_returns_empty_list(self, mock_client_class, mock_fastmcp): + """mcp_list_resources should return empty list when no resources.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_resources = AsyncMock(return_value=[]) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + result = await mcp.mcp_list_resources(mcp_engine=mock_fastmcp) + + assert result == [] + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(mcp, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in mcp.auth.routes] + + assert "/client" in routes + assert "/tools" in routes + assert "/resources" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(mcp, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert mcp.logger.name == "api.v1.mcp" diff --git a/test/unit/server/api/v1/test_v1_mcp_prompts.py b/test/unit/server/api/v1/test_v1_mcp_prompts.py new file mode 100644 index 00000000..46c518c8 --- /dev/null +++ b/test/unit/server/api/v1/test_v1_mcp_prompts.py @@ -0,0 +1,229 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/mcp_prompts.py +Tests for MCP prompt management endpoints. 
+""" + +from unittest.mock import patch, MagicMock, AsyncMock +import pytest +from fastapi import HTTPException + +from server.api.v1 import mcp_prompts + + +class TestMcpListPrompts: + """Tests for the mcp_list_prompts endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") + async def test_mcp_list_prompts_metadata_only(self, mock_list_prompts, mock_fastmcp): + """mcp_list_prompts should return metadata only when full=False.""" + mock_prompt = MagicMock() + mock_prompt.name = "optimizer_test-prompt" + mock_prompt.model_dump.return_value = {"name": "optimizer_test-prompt", "description": "Test"} + mock_list_prompts.return_value = [mock_prompt] + + result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) + + assert len(result) == 1 + assert result[0]["name"] == "optimizer_test-prompt" + mock_list_prompts.assert_called_once_with(mock_fastmcp) + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.utils_settings.get_mcp_prompts_with_overrides") + async def test_mcp_list_prompts_full(self, mock_get_prompts, mock_fastmcp, make_mcp_prompt): + """mcp_list_prompts should return full prompts with text when full=True.""" + mock_prompt = make_mcp_prompt(name="optimizer_test-prompt") + mock_get_prompts.return_value = [mock_prompt] + + result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=True) + + assert len(result) == 1 + assert "text" in result[0] + mock_get_prompts.assert_called_once_with(mock_fastmcp) + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") + async def test_mcp_list_prompts_filters_non_optimizer_prompts(self, mock_list_prompts, mock_fastmcp): + """mcp_list_prompts should filter out non-optimizer prompts.""" + optimizer_prompt = MagicMock() + optimizer_prompt.name = "optimizer_test-prompt" + optimizer_prompt.model_dump.return_value = {"name": "optimizer_test-prompt"} + + other_prompt = MagicMock() + other_prompt.name = "other-prompt" + other_prompt.model_dump.return_value = {"name": "other-prompt"} + + mock_list_prompts.return_value = [optimizer_prompt, other_prompt] + + result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) + + assert len(result) == 1 + assert result[0]["name"] == "optimizer_test-prompt" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") + async def test_mcp_list_prompts_empty_list(self, mock_list_prompts, mock_fastmcp): + """mcp_list_prompts should return empty list when no prompts.""" + mock_list_prompts.return_value = [] + + result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) + + assert result == [] + + +class TestMcpGetPrompt: + """Tests for the mcp_get_prompt endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + async def test_mcp_get_prompt_success(self, mock_client_class, mock_fastmcp): + """mcp_get_prompt should return prompt content.""" + mock_prompt_result = MagicMock() + mock_prompt_result.messages = [{"role": "user", "content": "Test content"}] + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.get_prompt = AsyncMock(return_value=mock_prompt_result) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + result = await mcp_prompts.mcp_get_prompt(name="optimizer_test-prompt", mcp_engine=mock_fastmcp) + + assert result == mock_prompt_result + 
mock_client.get_prompt.assert_called_once_with(name="optimizer_test-prompt") + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + async def test_mcp_get_prompt_closes_client(self, mock_client_class, mock_fastmcp): + """mcp_get_prompt should close client after use.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.get_prompt = AsyncMock(return_value=MagicMock()) + mock_client.close = AsyncMock() + mock_client_class.return_value = mock_client + + await mcp_prompts.mcp_get_prompt(name="test-prompt", mcp_engine=mock_fastmcp) + + mock_client.close.assert_called_once() + + +class TestMcpUpdatePrompt: + """Tests for the mcp_update_prompt endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + @patch("server.api.v1.mcp_prompts.cache") + async def test_mcp_update_prompt_success(self, mock_cache, mock_client_class, mock_fastmcp): + """mcp_update_prompt should update prompt and return success.""" + mock_prompt = MagicMock() + mock_prompt.name = "optimizer_test-prompt" + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) + mock_client_class.return_value = mock_client + + payload = {"instructions": "You are a helpful assistant."} + + result = await mcp_prompts.mcp_update_prompt( + name="optimizer_test-prompt", payload=payload, mcp_engine=mock_fastmcp + ) + + assert result["name"] == "optimizer_test-prompt" + assert "updated successfully" in result["message"] + mock_cache.set_override.assert_called_once_with("optimizer_test-prompt", "You are a helpful assistant.") + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + async def test_mcp_update_prompt_missing_instructions(self, _mock_client_class, mock_fastmcp): + """mcp_update_prompt should raise 400 when instructions missing.""" + payload = {"other_field": "value"} + + with pytest.raises(HTTPException) as exc_info: + await mcp_prompts.mcp_update_prompt(name="test-prompt", payload=payload, mcp_engine=mock_fastmcp) + + assert exc_info.value.status_code == 400 + assert "instructions" in str(exc_info.value.detail) + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + async def test_mcp_update_prompt_not_found(self, mock_client_class, mock_fastmcp): + """mcp_update_prompt should raise 404 when prompt not found.""" + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.list_prompts = AsyncMock(return_value=[]) + mock_client_class.return_value = mock_client + + payload = {"instructions": "New instructions"} + + with pytest.raises(HTTPException) as exc_info: + await mcp_prompts.mcp_update_prompt(name="nonexistent-prompt", payload=payload, mcp_engine=mock_fastmcp) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.mcp_prompts.Client") + @patch("server.api.v1.mcp_prompts.cache") + async def test_mcp_update_prompt_handles_exception(self, mock_cache, mock_client_class, mock_fastmcp): + """mcp_update_prompt should raise 500 on unexpected exception.""" + mock_prompt = MagicMock() + mock_prompt.name = "optimizer_test-prompt" + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) 
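+        # Client is presumably entered via "async with", so both context-manager
+        # hooks above must be AsyncMocks; plain MagicMocks would fail when awaited.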
+ mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) + mock_client_class.return_value = mock_client + + mock_cache.set_override.side_effect = RuntimeError("Cache error") + + payload = {"instructions": "New instructions"} + + with pytest.raises(HTTPException) as exc_info: + await mcp_prompts.mcp_update_prompt(name="optimizer_test-prompt", payload=payload, mcp_engine=mock_fastmcp) + + assert exc_info.value.status_code == 500 + + @pytest.mark.asyncio + async def test_mcp_update_prompt_none_instructions(self, mock_fastmcp): + """mcp_update_prompt should raise 400 when instructions is None.""" + payload = {"instructions": None} + + with pytest.raises(HTTPException) as exc_info: + await mcp_prompts.mcp_update_prompt(name="test-prompt", payload=payload, mcp_engine=mock_fastmcp) + + assert exc_info.value.status_code == 400 + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(mcp_prompts, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in mcp_prompts.auth.routes] + + assert "/prompts" in routes + assert "/prompts/{name}" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(mcp_prompts, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert mcp_prompts.logger.name == "api.v1.mcp_prompts" diff --git a/test/unit/server/api/v1/test_v1_models.py b/test/unit/server/api/v1/test_v1_models.py new file mode 100644 index 00000000..6a4f721e --- /dev/null +++ b/test/unit/server/api/v1/test_v1_models.py @@ -0,0 +1,254 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/models.py +Tests for model configuration endpoints. 
+""" + +import json +from unittest.mock import patch + +import pytest +from fastapi import HTTPException + +from server.api.v1 import models +from server.api.utils import models as utils_models + + +class TestModelsList: + """Tests for the models_list endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_list_returns_all_models(self, mock_get, make_model): + """models_list should return all configured models.""" + model_list = [ + make_model(model_id="gpt-4", provider="openai"), + make_model(model_id="claude-3", provider="anthropic"), + ] + mock_get.return_value = model_list + + result = await models.models_list() + + assert result == model_list + mock_get.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_list_with_type_filter(self, mock_get): + """models_list should filter by model type when provided.""" + mock_get.return_value = [] + + await models.models_list(model_type="ll") + + mock_get.assert_called_once() + # Verify the model_type was passed (FastAPI Query wraps the value) + call_kwargs = mock_get.call_args.kwargs + assert call_kwargs.get("model_type") == "ll" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_list_with_include_disabled(self, mock_get): + """models_list should include disabled models when requested.""" + mock_get.return_value = [] + + await models.models_list(include_disabled=True) + + mock_get.assert_called_once() + # Verify the include_disabled was passed + call_kwargs = mock_get.call_args.kwargs + assert call_kwargs.get("include_disabled") is True + + +class TestModelsSupported: + """Tests for the models_supported endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get_supported") + async def test_models_supported_returns_supported_list(self, mock_get_supported): + """models_supported should return list of supported models.""" + supported_models = [ + {"provider": "openai", "models": ["gpt-4", "gpt-4o"]}, + ] + mock_get_supported.return_value = supported_models + + result = await models.models_supported(model_provider="openai") + + assert result == supported_models + mock_get_supported.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get_supported") + async def test_models_supported_filters_by_type(self, mock_get_supported): + """models_supported should filter by model type when provided.""" + mock_get_supported.return_value = [] + + await models.models_supported(model_provider="openai", model_type="ll") + + mock_get_supported.assert_called_once() + call_kwargs = mock_get_supported.call_args.kwargs + assert call_kwargs.get("model_provider") == "openai" + assert call_kwargs.get("model_type") == "ll" + + +class TestModelsGet: + """Tests for the models_get endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_get_returns_single_model(self, mock_get, make_model): + """models_get should return a single model by ID.""" + model = make_model(model_id="gpt-4", provider="openai") + mock_get.return_value = (model,) # Returns a tuple that unpacks + + result = await models.models_get(model_provider="openai", model_id="gpt-4") + + assert result == model + mock_get.assert_called_once_with(model_provider="openai", model_id="gpt-4") + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_get_raises_404_when_not_found(self, mock_get): + 
"""models_get should raise 404 when model not found.""" + mock_get.side_effect = utils_models.UnknownModelError("Model not found") + + with pytest.raises(HTTPException) as exc_info: + await models.models_get(model_provider="openai", model_id="nonexistent") + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.get") + async def test_models_get_raises_404_on_multiple_results(self, mock_get, make_model): + """models_get should raise 404 when multiple models match.""" + # Returning a tuple with more than 1 element causes ValueError on unpack + mock_get.return_value = (make_model(), make_model()) + + with pytest.raises(HTTPException) as exc_info: + await models.models_get(model_provider="openai", model_id="gpt-4") + + assert exc_info.value.status_code == 404 + + +class TestModelsUpdate: + """Tests for the models_update endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.update") + async def test_models_update_returns_updated_model(self, mock_update, make_model): + """models_update should return the updated model.""" + updated_model = make_model(model_id="gpt-4", provider="openai", enabled=False) + mock_update.return_value = updated_model + + payload = make_model(model_id="gpt-4", provider="openai") + result = await models.models_update(payload=payload) + + assert result == updated_model + mock_update.assert_called_once_with(payload=payload) + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.update") + async def test_models_update_raises_404_when_not_found(self, mock_update, make_model): + """models_update should raise 404 when model not found.""" + mock_update.side_effect = utils_models.UnknownModelError("Model not found") + + payload = make_model(model_id="nonexistent", provider="openai") + + with pytest.raises(HTTPException) as exc_info: + await models.models_update(payload=payload) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.update") + async def test_models_update_raises_422_on_unreachable_url(self, mock_update, make_model): + """models_update should raise 422 when API URL is unreachable.""" + mock_update.side_effect = utils_models.URLUnreachableError("URL unreachable") + + payload = make_model(model_id="gpt-4", provider="openai") + + with pytest.raises(HTTPException) as exc_info: + await models.models_update(payload=payload) + + assert exc_info.value.status_code == 422 + + +class TestModelsCreate: + """Tests for the models_create endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.create") + async def test_models_create_returns_new_model(self, mock_create, make_model): + """models_create should return newly created model.""" + new_model = make_model(model_id="new-model", provider="openai") + mock_create.return_value = new_model + + result = await models.models_create(payload=make_model(model_id="new-model", provider="openai")) + + assert result == new_model + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.create") + async def test_models_create_raises_409_on_duplicate(self, mock_create, make_model): + """models_create should raise 409 when model already exists.""" + mock_create.side_effect = utils_models.ExistsModelError("Model already exists") + + with pytest.raises(HTTPException) as exc_info: + await models.models_create(payload=make_model()) + + assert exc_info.value.status_code == 409 + + +class TestModelsDelete: + """Tests for the models_delete endpoint.""" + + 
@pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.delete") + async def test_models_delete_returns_200_on_success(self, mock_delete): + """models_delete should return 200 status on success.""" + mock_delete.return_value = None + + result = await models.models_delete(model_provider="openai", model_id="gpt-4") + + assert result.status_code == 200 + mock_delete.assert_called_once_with(model_provider="openai", model_id="gpt-4") + + @pytest.mark.asyncio + @patch("server.api.v1.models.utils_models.delete") + async def test_models_delete_response_contains_message(self, mock_delete): + """models_delete should return message with model name.""" + mock_delete.return_value = None + + result = await models.models_delete(model_provider="openai", model_id="gpt-4") + + body = json.loads(result.body) + assert "openai/gpt-4" in body["message"] + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(models, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in models.auth.routes] + + assert "" in routes + assert "/supported" in routes + assert "/{model_provider}/{model_id:path}" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(models, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert models.logger.name == "endpoints.v1.models" diff --git a/test/unit/server/api/v1/test_v1_oci.py b/test/unit/server/api/v1/test_v1_oci.py new file mode 100644 index 00000000..4402e96c --- /dev/null +++ b/test/unit/server/api/v1/test_v1_oci.py @@ -0,0 +1,362 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/oci.py +Tests for OCI configuration and resource endpoints. 
+""" + +# pylint: disable=too-few-public-methods + +from unittest.mock import patch +import pytest +from fastapi import HTTPException + +from server.api.v1 import oci +from server.api.utils.oci import OciException + + +class TestOciList: + """Tests for the oci_list endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.utils_oci.get") + async def test_oci_list_returns_all_configs(self, mock_get, make_oci_config): + """oci_list should return all OCI configurations.""" + configs = [make_oci_config(auth_profile="DEFAULT"), make_oci_config(auth_profile="PROD")] + mock_get.return_value = configs + + result = await oci.oci_list() + + assert result == configs + mock_get.assert_called_once_with() + + @pytest.mark.asyncio + @patch("server.api.v1.oci.utils_oci.get") + async def test_oci_list_raises_404_on_value_error(self, mock_get): + """oci_list should raise 404 when ValueError occurs.""" + mock_get.side_effect = ValueError("No configs found") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list() + + assert exc_info.value.status_code == 404 + assert "OCI:" in str(exc_info.value.detail) + + +class TestOciGet: + """Tests for the oci_get endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.utils_oci.get") + async def test_oci_get_returns_single_config(self, mock_get, make_oci_config): + """oci_get should return a single OCI config by profile.""" + config = make_oci_config(auth_profile="DEFAULT") + mock_get.return_value = config + + result = await oci.oci_get(auth_profile="DEFAULT") + + assert result == config + mock_get.assert_called_once_with(auth_profile="DEFAULT") + + @pytest.mark.asyncio + @patch("server.api.v1.oci.utils_oci.get") + async def test_oci_get_raises_404_when_not_found(self, mock_get): + """oci_get should raise 404 when profile not found.""" + mock_get.side_effect = ValueError("Profile not found") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_get(auth_profile="NONEXISTENT") + + assert exc_info.value.status_code == 404 + + +class TestOciListRegions: + """Tests for the oci_list_regions endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_regions") + async def test_oci_list_regions_success(self, mock_get_regions, mock_oci_get, make_oci_config): + """oci_list_regions should return list of regions.""" + config = make_oci_config() + mock_oci_get.return_value = config + mock_get_regions.return_value = ["us-ashburn-1", "us-phoenix-1"] + + result = await oci.oci_list_regions(auth_profile="DEFAULT") + + assert result == ["us-ashburn-1", "us-phoenix-1"] + mock_get_regions.assert_called_once_with(config) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_regions") + async def test_oci_list_regions_raises_on_oci_exception(self, mock_get_regions, mock_oci_get, make_oci_config): + """oci_list_regions should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_regions.side_effect = OciException(status_code=401, detail="Unauthorized") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list_regions(auth_profile="DEFAULT") + + assert exc_info.value.status_code == 401 + + +class TestOciListGenai: + """Tests for the oci_list_genai endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_genai_models") + async def test_oci_list_genai_success(self, mock_get_genai, mock_oci_get, make_oci_config): + 
"""oci_list_genai should return list of GenAI models.""" + config = make_oci_config() + mock_oci_get.return_value = config + mock_get_genai.return_value = [{"name": "cohere.command"}, {"name": "meta.llama"}] + + result = await oci.oci_list_genai(auth_profile="DEFAULT") + + assert len(result) == 2 + mock_get_genai.assert_called_once_with(config, regional=False) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_genai_models") + async def test_oci_list_genai_raises_on_oci_exception(self, mock_get_genai, mock_oci_get, make_oci_config): + """oci_list_genai should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_genai.side_effect = OciException(status_code=403, detail="Forbidden") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list_genai(auth_profile="DEFAULT") + + assert exc_info.value.status_code == 403 + + +class TestOciListCompartments: + """Tests for the oci_list_compartments endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_compartments") + async def test_oci_list_compartments_success(self, mock_get_compartments, mock_oci_get, make_oci_config): + """oci_list_compartments should return compartment hierarchy.""" + config = make_oci_config() + mock_oci_get.return_value = config + compartments = {"root": {"name": "root", "children": []}} + mock_get_compartments.return_value = compartments + + result = await oci.oci_list_compartments(auth_profile="DEFAULT") + + assert result == compartments + mock_get_compartments.assert_called_once_with(config) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_compartments") + async def test_oci_list_compartments_raises_on_oci_exception( + self, mock_get_compartments, mock_oci_get, make_oci_config + ): + """oci_list_compartments should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_compartments.side_effect = OciException(status_code=500, detail="Internal error") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list_compartments(auth_profile="DEFAULT") + + assert exc_info.value.status_code == 500 + + +class TestOciListBuckets: + """Tests for the oci_list_buckets endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_buckets") + async def test_oci_list_buckets_success(self, mock_get_buckets, mock_oci_get, make_oci_config): + """oci_list_buckets should return list of buckets.""" + config = make_oci_config() + mock_oci_get.return_value = config + mock_get_buckets.return_value = ["bucket1", "bucket2"] + compartment_ocid = "ocid1.compartment.oc1..aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + + result = await oci.oci_list_buckets(auth_profile="DEFAULT", compartment_ocid=compartment_ocid) + + assert result == ["bucket1", "bucket2"] + mock_get_buckets.assert_called_once_with(compartment_ocid, config) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_buckets") + async def test_oci_list_buckets_raises_on_oci_exception(self, mock_get_buckets, mock_oci_get, make_oci_config): + """oci_list_buckets should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_buckets.side_effect = OciException(status_code=404, detail="Bucket not found") + compartment_ocid = 
"ocid1.compartment.oc1..aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list_buckets(auth_profile="DEFAULT", compartment_ocid=compartment_ocid) + + assert exc_info.value.status_code == 404 + + +class TestOciListBucketObjects: + """Tests for the oci_list_bucket_objects endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_bucket_objects") + async def test_oci_list_bucket_objects_success(self, mock_get_objects, mock_oci_get, make_oci_config): + """oci_list_bucket_objects should return list of objects.""" + config = make_oci_config() + mock_oci_get.return_value = config + mock_get_objects.return_value = ["file1.pdf", "file2.txt"] + + result = await oci.oci_list_bucket_objects(auth_profile="DEFAULT", bucket_name="my-bucket") + + assert result == ["file1.pdf", "file2.txt"] + mock_get_objects.assert_called_once_with("my-bucket", config) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_bucket_objects") + async def test_oci_list_bucket_objects_raises_on_oci_exception( + self, mock_get_objects, mock_oci_get, make_oci_config + ): + """oci_list_bucket_objects should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_get_objects.side_effect = OciException(status_code=403, detail="Access denied") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_list_bucket_objects(auth_profile="DEFAULT", bucket_name="my-bucket") + + assert exc_info.value.status_code == 403 + + +class TestOciProfileUpdate: + """Tests for the oci_profile_update endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_namespace") + async def test_oci_profile_update_success(self, mock_get_namespace, mock_oci_get, make_oci_config): + """oci_profile_update should update and return config.""" + config = make_oci_config(auth_profile="DEFAULT") + mock_oci_get.return_value = config + mock_get_namespace.return_value = "test-namespace" + + payload = make_oci_config(auth_profile="DEFAULT", genai_region="us-phoenix-1") + + result = await oci.oci_profile_update(auth_profile="DEFAULT", payload=payload) + + assert result.namespace == "test-namespace" + assert result.genai_region == "us-phoenix-1" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_oci.get_namespace") + async def test_oci_profile_update_raises_on_oci_exception(self, mock_get_namespace, mock_oci_get, make_oci_config): + """oci_profile_update should raise HTTPException on OciException.""" + config = make_oci_config() + mock_oci_get.return_value = config + mock_get_namespace.side_effect = OciException(status_code=401, detail="Invalid credentials") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_profile_update(auth_profile="DEFAULT", payload=make_oci_config()) + + assert exc_info.value.status_code == 401 + assert config.namespace is None + + +class TestOciDownloadObjects: + """Tests for the oci_download_objects endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_embed.get_temp_directory") + @patch("server.api.v1.oci.utils_oci.get_object") + async def test_oci_download_objects_success( + self, mock_get_object, mock_get_temp_dir, mock_oci_get, make_oci_config, tmp_path + ): + """oci_download_objects should download files and return list.""" + config = 
make_oci_config() + mock_oci_get.return_value = config + mock_get_temp_dir.return_value = tmp_path + + # Create test files + (tmp_path / "file1.pdf").touch() + (tmp_path / "file2.txt").touch() + + result = await oci.oci_download_objects( + bucket_name="my-bucket", + auth_profile="DEFAULT", + request=["file1.pdf", "file2.txt"], + client="test_client", + ) + + assert result.status_code == 200 + assert mock_get_object.call_count == 2 + + +class TestOciCreateGenaiModels: + """Tests for the oci_create_genai_models endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_models.create_genai") + async def test_oci_create_genai_models_success(self, mock_create_genai, mock_oci_get, make_oci_config, make_model): + """oci_create_genai_models should create and return models.""" + config = make_oci_config() + mock_oci_get.return_value = config + models_list = [make_model(model_id="cohere.command", provider="oci")] + mock_create_genai.return_value = models_list + + result = await oci.oci_create_genai_models(auth_profile="DEFAULT") + + assert result == models_list + mock_create_genai.assert_called_once_with(config) + + @pytest.mark.asyncio + @patch("server.api.v1.oci.oci_get") + @patch("server.api.v1.oci.utils_models.create_genai") + async def test_oci_create_genai_models_raises_on_oci_exception( + self, mock_create_genai, mock_oci_get, make_oci_config + ): + """oci_create_genai_models should raise HTTPException on OciException.""" + mock_oci_get.return_value = make_oci_config() + mock_create_genai.side_effect = OciException(status_code=500, detail="GenAI service error") + + with pytest.raises(HTTPException) as exc_info: + await oci.oci_create_genai_models(auth_profile="DEFAULT") + + assert exc_info.value.status_code == 500 + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(oci, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in oci.auth.routes] + + assert "" in routes + assert "/{auth_profile}" in routes + assert "/regions/{auth_profile}" in routes + assert "/genai/{auth_profile}" in routes + assert "/compartments/{auth_profile}" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(oci, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert oci.logger.name == "endpoints.v1.oci" diff --git a/test/unit/server/api/v1/test_v1_probes.py b/test/unit/server/api/v1/test_v1_probes.py new file mode 100644 index 00000000..e716a5ff --- /dev/null +++ b/test/unit/server/api/v1/test_v1_probes.py @@ -0,0 +1,129 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/probes.py +Tests for Kubernetes health probe endpoints. 
+""" + +import asyncio +from unittest.mock import MagicMock + +import pytest + +from server.api.v1 import probes + + +class TestGetMcp: + """Tests for the get_mcp dependency function.""" + + def test_get_mcp_returns_fastmcp_app(self): + """get_mcp should return the FastMCP app from request state.""" + mock_request = MagicMock() + mock_fastmcp = MagicMock() + mock_request.app.state.fastmcp_app = mock_fastmcp + + result = probes.get_mcp(mock_request) + + assert result == mock_fastmcp + + def test_get_mcp_accesses_correct_state_attribute(self): + """get_mcp should access app.state.fastmcp_app.""" + mock_request = MagicMock() + + probes.get_mcp(mock_request) + + _ = mock_request.app.state.fastmcp_app # Verify attribute access + + +class TestLivenessProbe: + """Tests for the liveness_probe endpoint.""" + + @pytest.mark.asyncio + async def test_liveness_probe_returns_alive(self): + """liveness_probe should return alive status.""" + result = await probes.liveness_probe() + + assert result == {"status": "alive"} + + @pytest.mark.asyncio + async def test_liveness_probe_is_async(self): + """liveness_probe should be an async function.""" + assert asyncio.iscoroutinefunction(probes.liveness_probe) + + +class TestReadinessProbe: + """Tests for the readiness_probe endpoint.""" + + @pytest.mark.asyncio + async def test_readiness_probe_returns_ready(self): + """readiness_probe should return ready status.""" + result = await probes.readiness_probe() + + assert result == {"status": "ready"} + + @pytest.mark.asyncio + async def test_readiness_probe_is_async(self): + """readiness_probe should be an async function.""" + assert asyncio.iscoroutinefunction(probes.readiness_probe) + + +class TestMcpHealthz: + """Tests for the mcp_healthz endpoint.""" + + def test_mcp_healthz_returns_ready_status(self): + """mcp_healthz should return ready status with server info.""" + mock_fastmcp = MagicMock() + mock_fastmcp.__dict__["_mcp_server"] = MagicMock() + mock_fastmcp.__dict__["_mcp_server"].__dict__ = { + "name": "test-server", + "version": "1.0.0", + } + mock_fastmcp.available_tools = ["tool1", "tool2"] + + result = probes.mcp_healthz(mock_fastmcp) + + assert result["status"] == "ready" + assert result["name"] == "test-server" + assert result["version"] == "1.0.0" + assert result["available_tools"] == 2 + + def test_mcp_healthz_returns_not_ready_when_none(self): + """mcp_healthz should return not ready when mcp_engine is None.""" + result = probes.mcp_healthz(None) + + assert result["status"] == "not ready" + + def test_mcp_healthz_with_no_available_tools(self): + """mcp_healthz should handle missing available_tools attribute.""" + mock_fastmcp = MagicMock(spec=[]) # No available_tools attribute + mock_fastmcp.__dict__["_mcp_server"] = MagicMock() + mock_fastmcp.__dict__["_mcp_server"].__dict__ = { + "name": "test-server", + "version": "1.0.0", + } + + result = probes.mcp_healthz(mock_fastmcp) + + assert result["status"] == "ready" + assert result["available_tools"] == 0 + + def test_mcp_healthz_is_not_async(self): + """mcp_healthz should be a sync function.""" + assert not asyncio.iscoroutinefunction(probes.mcp_healthz) + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_noauth_router_exists(self): + """The noauth router should be defined.""" + assert hasattr(probes, "noauth") + + def test_noauth_router_has_routes(self): + """The noauth router should have registered routes.""" + routes = [route.path for route in probes.noauth.routes] + + assert "/liveness" in routes + assert 
"/readiness" in routes + assert "/mcp/healthz" in routes diff --git a/test/unit/server/api/v1/test_v1_settings.py b/test/unit/server/api/v1/test_v1_settings.py new file mode 100644 index 00000000..348613a1 --- /dev/null +++ b/test/unit/server/api/v1/test_v1_settings.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/api/v1/settings.py +Tests for client settings management endpoints. +""" + +from unittest.mock import patch, MagicMock +from io import BytesIO +import json +import pytest +from fastapi import HTTPException, UploadFile +from fastapi.responses import JSONResponse + +from server.api.v1 import settings + + +class TestSettingsGet: + """Tests for the settings_get endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.get_client") + async def test_settings_get_returns_client_settings(self, mock_get_client, make_settings): + """settings_get should return client settings.""" + client_settings = make_settings(client="test_client") + mock_get_client.return_value = client_settings + + mock_request = MagicMock() + + result = await settings.settings_get( + request=mock_request, client="test_client", full_config=False, incl_sensitive=False, incl_readonly=False + ) + + assert result == client_settings + mock_get_client.assert_called_once_with("test_client") + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.get_client") + async def test_settings_get_raises_404_when_not_found(self, mock_get_client): + """settings_get should raise 404 when client not found.""" + mock_get_client.side_effect = ValueError("Client not found") + + mock_request = MagicMock() + + with pytest.raises(HTTPException) as exc_info: + await settings.settings_get( + request=mock_request, + client="nonexistent", + full_config=False, + incl_sensitive=False, + incl_readonly=False, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.get_client") + @patch("server.api.v1.settings.utils_settings.get_server") + async def test_settings_get_full_config(self, mock_get_server, mock_get_client, make_settings, mock_fastmcp): + """settings_get should return full config when requested.""" + client_settings = make_settings(client="test_client") + mock_get_client.return_value = client_settings + mock_get_server.return_value = { + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + mock_request = MagicMock() + mock_request.app.state.fastmcp_app = mock_fastmcp + + result = await settings.settings_get( + request=mock_request, client="test_client", full_config=True, incl_sensitive=False, incl_readonly=False + ) + + assert isinstance(result, JSONResponse) + mock_get_server.assert_called_once() + + +class TestSettingsUpdate: + """Tests for the settings_update endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.update_client") + async def test_settings_update_success(self, mock_update_client, make_settings): + """settings_update should update and return settings.""" + updated_settings = make_settings(client="test_client", temperature=0.9) + mock_update_client.return_value = updated_settings + + payload = make_settings(client="test_client", temperature=0.9) + + result = await settings.settings_update(payload=payload, client="test_client") + + assert result == updated_settings + 
mock_update_client.assert_called_once_with(payload, "test_client") + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.update_client") + async def test_settings_update_raises_404_when_not_found(self, mock_update_client, make_settings): + """settings_update should raise 404 when client not found.""" + mock_update_client.side_effect = ValueError("Client not found") + + payload = make_settings(client="nonexistent") + + with pytest.raises(HTTPException) as exc_info: + await settings.settings_update(payload=payload, client="nonexistent") + + assert exc_info.value.status_code == 404 + + +class TestSettingsCreate: + """Tests for the settings_create endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + async def test_settings_create_success(self, mock_create_client, make_settings): + """settings_create should create and return new settings.""" + new_settings = make_settings(client="new_client") + mock_create_client.return_value = new_settings + + result = await settings.settings_create(client="new_client") + + assert result == new_settings + mock_create_client.assert_called_once_with("new_client") + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + async def test_settings_create_raises_409_when_exists(self, mock_create_client): + """settings_create should raise 409 when client already exists.""" + mock_create_client.side_effect = ValueError("Client already exists") + + with pytest.raises(HTTPException) as exc_info: + await settings.settings_create(client="existing_client") + + assert exc_info.value.status_code == 409 + + +class TestLoadSettingsFromFile: + """Tests for the load_settings_from_file endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_file_success(self, mock_load_config, mock_create_client): + """load_settings_from_file should load config from JSON file.""" + mock_create_client.return_value = MagicMock() + mock_load_config.return_value = None + + config_data = {"client_settings": {"client": "test"}, "database_configs": []} + file_content = json.dumps(config_data).encode() + mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") + + result = await settings.load_settings_from_file(client="test_client", file=mock_file) + + assert result["message"] == "Configuration loaded successfully." + mock_load_config.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + async def test_load_settings_from_file_wrong_extension(self, mock_create_client): + """load_settings_from_file should raise error for non-JSON files. + + Note: Due to the generic exception handler in the source code, + HTTPException(400) is caught and wrapped in HTTPException(500). 
+ """ + mock_create_client.return_value = MagicMock() + + mock_file = UploadFile(file=BytesIO(b"data"), filename="config.txt") + + with pytest.raises(HTTPException) as exc_info: + await settings.load_settings_from_file(client="test_client", file=mock_file) + + # The 400 HTTPException gets caught by generic exception handler and wrapped in 500 + assert exc_info.value.status_code == 500 + assert "JSON" in str(exc_info.value.detail) + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + async def test_load_settings_from_file_invalid_json(self, mock_create_client): + """load_settings_from_file should raise 400 for invalid JSON.""" + mock_create_client.return_value = MagicMock() + + mock_file = UploadFile(file=BytesIO(b"not valid json"), filename="config.json") + + with pytest.raises(HTTPException) as exc_info: + await settings.load_settings_from_file(client="test_client", file=mock_file) + + assert exc_info.value.status_code == 400 + assert "Invalid JSON" in str(exc_info.value.detail) + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_file_key_error(self, mock_load_config, mock_create_client): + """load_settings_from_file should raise 400 on KeyError.""" + mock_create_client.return_value = MagicMock() + mock_load_config.side_effect = KeyError("Missing required key") + + config_data = {"incomplete": "data"} + file_content = json.dumps(config_data).encode() + mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") + + with pytest.raises(HTTPException) as exc_info: + await settings.load_settings_from_file(client="test_client", file=mock_file) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_file_handles_existing_client(self, mock_load_config, mock_create_client): + """load_settings_from_file should continue if client already exists.""" + mock_create_client.side_effect = ValueError("Client already exists") + mock_load_config.return_value = None + + config_data = {"client_settings": {"client": "test"}} + file_content = json.dumps(config_data).encode() + mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") + + result = await settings.load_settings_from_file(client="test_client", file=mock_file) + + assert result["message"] == "Configuration loaded successfully." + + +class TestLoadSettingsFromJson: + """Tests for the load_settings_from_json endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_json_success(self, mock_load_config, mock_create_client, make_configuration): + """load_settings_from_json should load config from JSON payload.""" + mock_create_client.return_value = MagicMock() + mock_load_config.return_value = None + + payload = make_configuration(client="test_client") + + result = await settings.load_settings_from_json(client="test_client", payload=payload) + + assert result["message"] == "Configuration loaded successfully." 
+ mock_load_config.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_json_key_error(self, mock_load_config, mock_create_client, make_configuration): + """load_settings_from_json should raise 400 on KeyError.""" + mock_create_client.return_value = MagicMock() + mock_load_config.side_effect = KeyError("Missing required key") + + payload = make_configuration(client="test_client") + + with pytest.raises(HTTPException) as exc_info: + await settings.load_settings_from_json(client="test_client", payload=payload) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + @patch("server.api.v1.settings.utils_settings.create_client") + @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") + async def test_load_settings_from_json_handles_existing_client( + self, mock_load_config, mock_create_client, make_configuration + ): + """load_settings_from_json should continue if client already exists.""" + mock_create_client.side_effect = ValueError("Client already exists") + mock_load_config.return_value = None + + payload = make_configuration(client="test_client") + + result = await settings.load_settings_from_json(client="test_client", payload=payload) + + assert result["message"] == "Configuration loaded successfully." + + +class TestIncludeParams: # pylint: disable=protected-access + """Tests for the include parameter dependencies.""" + + def test_incl_sensitive_param_default(self): + """_incl_sensitive_param should default to False.""" + result = settings._incl_sensitive_param(incl_sensitive=False) + assert result is False + + def test_incl_sensitive_param_true(self): + """_incl_sensitive_param should return True when set.""" + result = settings._incl_sensitive_param(incl_sensitive=True) + assert result is True + + def test_incl_readonly_param_default(self): + """_incl_readonly_param should default to False.""" + result = settings._incl_readonly_param(incl_readonly=False) + assert result is False + + def test_incl_readonly_param_true(self): + """_incl_readonly_param should return True when set.""" + result = settings._incl_readonly_param(incl_readonly=True) + assert result is True + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(settings, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in settings.auth.routes] + + assert "" in routes # Get, Update, Create + assert "/load/file" in routes + assert "/load/json" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(settings, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert settings.logger.name == "endpoints.v1.settings" diff --git a/test/unit/server/api/v1/test_v1_testbed.py b/test/unit/server/api/v1/test_v1_testbed.py new file mode 100644 index 00000000..14edd3dd --- /dev/null +++ b/test/unit/server/api/v1/test_v1_testbed.py @@ -0,0 +1,305 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+ +Unit tests for server/api/v1/testbed.py +Tests for Q&A testbed and evaluation endpoints. +""" +# pylint: disable=protected-access,too-few-public-methods + +from unittest.mock import patch, MagicMock +from io import BytesIO +import pytest +from fastapi import HTTPException, UploadFile +import litellm + +from server.api.v1 import testbed +from common.schema import TestSets, TestSetQA, Evaluation, EvaluationReport + + +class TestTestbedTestsets: + """Tests for the testbed_testsets endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_testsets") + async def test_testbed_testsets_returns_list(self, mock_get_testsets, mock_get_db, mock_db_connection): + """testbed_testsets should return list of testsets.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + + mock_testsets = [ + TestSets(tid="TS001", name="Test Set 1", created="2024-01-01"), + TestSets(tid="TS002", name="Test Set 2", created="2024-01-02"), + ] + mock_get_testsets.return_value = mock_testsets + + result = await testbed.testbed_testsets(client="test_client") + + assert result == mock_testsets + mock_get_testsets.assert_called_once_with(db_conn=mock_db_connection) + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_testsets") + async def test_testbed_testsets_empty_list(self, mock_get_testsets, mock_get_db, mock_db_connection): + """testbed_testsets should return empty list when no testsets.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + mock_get_testsets.return_value = [] + + result = await testbed.testbed_testsets(client="test_client") + + assert result == [] + + +class TestTestbedEvaluations: + """Tests for the testbed_evaluations endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_evaluations") + async def test_testbed_evaluations_returns_list(self, mock_get_evals, mock_get_db, mock_db_connection): + """testbed_evaluations should return list of evaluations.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + + mock_evals = [ + Evaluation(eid="EV001", evaluated="2024-01-01", correctness=0.85), + Evaluation(eid="EV002", evaluated="2024-01-02", correctness=0.90), + ] + mock_get_evals.return_value = mock_evals + + result = await testbed.testbed_evaluations(tid="ts001", client="test_client") + + assert result == mock_evals + mock_get_evals.assert_called_once_with(db_conn=mock_db_connection, tid="TS001") + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_evaluations") + async def test_testbed_evaluations_uppercases_tid(self, mock_get_evals, mock_get_db, mock_db_connection): + """testbed_evaluations should uppercase the tid.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + mock_get_evals.return_value = [] + + await testbed.testbed_evaluations(tid="lowercase", client="test_client") + + mock_get_evals.assert_called_once_with(db_conn=mock_db_connection, tid="LOWERCASE") + + +class TestTestbedEvaluation: + """Tests for the testbed_evaluation endpoint.""" + + @pytest.mark.asyncio + 
@patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.process_report") + async def test_testbed_evaluation_returns_report(self, mock_process_report, mock_get_db, mock_db_connection): + """testbed_evaluation should return evaluation report.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + + mock_report = MagicMock(spec=EvaluationReport) + mock_process_report.return_value = mock_report + + result = await testbed.testbed_evaluation(eid="ev001", client="test_client") + + assert result == mock_report + mock_process_report.assert_called_once_with(db_conn=mock_db_connection, eid="EV001") + + +class TestTestbedTestsetQa: + """Tests for the testbed_testset_qa endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") + async def test_testbed_testset_qa_returns_data(self, mock_get_qa, mock_get_db, mock_db_connection): + """testbed_testset_qa should return Q&A data.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + + mock_qa = TestSetQA(qa_data=[{"question": "Q1", "answer": "A1"}]) + mock_get_qa.return_value = mock_qa + + result = await testbed.testbed_testset_qa(tid="ts001", client="test_client") + + assert result == mock_qa + mock_get_qa.assert_called_once_with(db_conn=mock_db_connection, tid="TS001") + + +class TestTestbedDeleteTestset: + """Tests for the testbed_delete_testset endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.delete_qa") + async def test_testbed_delete_testset_success(self, mock_delete_qa, mock_get_db, mock_db_connection): + """testbed_delete_testset should delete and return success.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + mock_delete_qa.return_value = None + + result = await testbed.testbed_delete_testset(tid="ts001", client="test_client") + + assert result.status_code == 200 + mock_delete_qa.assert_called_once_with(mock_db_connection, "TS001") + + +class TestTestbedUpsertTestsets: + """Tests for the testbed_upsert_testsets endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.jsonl_to_json_content") + @patch("server.api.v1.testbed.utils_testbed.upsert_qa") + @patch("server.api.v1.testbed.testbed_testset_qa") + async def test_testbed_upsert_testsets_success( + self, mock_testset_qa, mock_upsert, mock_jsonl, mock_get_db, mock_db_connection + ): + """testbed_upsert_testsets should upload and return Q&A.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + mock_jsonl.return_value = [{"question": "Q1", "answer": "A1"}] + mock_upsert.return_value = "TS001" + mock_testset_qa.return_value = TestSetQA(qa_data=[{"question": "Q1"}]) + + mock_file = UploadFile(file=BytesIO(b'{"question": "Q1"}'), filename="test.jsonl") + + result = await testbed.testbed_upsert_testsets( + files=[mock_file], name="Test Set", tid=None, client="test_client" + ) + + assert isinstance(result, TestSetQA) + mock_db_connection.commit.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_databases.get_client_database") + 
@patch("server.api.v1.testbed.utils_testbed.jsonl_to_json_content") + async def test_testbed_upsert_testsets_handles_exception(self, mock_jsonl, mock_get_db, mock_db_connection): + """testbed_upsert_testsets should raise 500 on exception.""" + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + mock_jsonl.side_effect = Exception("Parse error") + + mock_file = UploadFile(file=BytesIO(b"invalid"), filename="test.jsonl") + + with pytest.raises(HTTPException) as exc_info: + await testbed.testbed_upsert_testsets(files=[mock_file], name="Test", tid=None, client="test_client") + + assert exc_info.value.status_code == 500 + + +class TestHandleTestsetError: + """Tests for the _handle_testset_error helper function.""" + + def test_handle_testset_error_key_error_columns(self, tmp_path): + """_handle_testset_error should raise 400 for column KeyError.""" + ex = KeyError("None of ['col1'] are in the columns") + + with pytest.raises(HTTPException) as exc_info: + testbed._handle_testset_error(ex, tmp_path, "test-model") + + assert exc_info.value.status_code == 400 + assert "test-model" in str(exc_info.value.detail) + + def test_handle_testset_error_value_error(self, tmp_path): + """_handle_testset_error should raise 400 for ValueError.""" + ex = ValueError("Invalid value") + + with pytest.raises(HTTPException) as exc_info: + testbed._handle_testset_error(ex, tmp_path, "test-model") + + assert exc_info.value.status_code == 400 + + def test_handle_testset_error_api_connection_error(self, tmp_path): + """_handle_testset_error should raise 424 for API connection error.""" + ex = litellm.APIConnectionError(message="Connection failed", llm_provider="openai", model="gpt-4") + + with pytest.raises(HTTPException) as exc_info: + testbed._handle_testset_error(ex, tmp_path, "test-model") + + assert exc_info.value.status_code == 424 + + def test_handle_testset_error_unknown_exception(self, tmp_path): + """_handle_testset_error should raise 500 for unknown exceptions.""" + ex = RuntimeError("Unknown error") + + with pytest.raises(HTTPException) as exc_info: + testbed._handle_testset_error(ex, tmp_path, "test-model") + + assert exc_info.value.status_code == 500 + + def test_handle_testset_error_other_key_error(self, tmp_path): + """_handle_testset_error should re-raise other KeyErrors.""" + ex = KeyError("some_other_key") + + with pytest.raises(KeyError): + testbed._handle_testset_error(ex, tmp_path, "test-model") + + +class TestTestbedGenerateQa: + """Tests for the testbed_generate_qa endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_oci.get") + async def test_testbed_generate_qa_raises_400_on_value_error(self, mock_oci_get): + """testbed_generate_qa should raise 400 on ValueError.""" + mock_oci_get.side_effect = ValueError("Invalid OCI config") + + mock_file = UploadFile(file=BytesIO(b"content"), filename="test.txt") + + with pytest.raises(HTTPException) as exc_info: + await testbed.testbed_generate_qa( + files=[mock_file], + name="Test", + ll_model="gpt-4", + embed_model="text-embedding-3", + questions=2, + client="test_client", + ) + + assert exc_info.value.status_code == 400 + + +class TestRouterConfiguration: + """Tests for router configuration.""" + + def test_auth_router_exists(self): + """The auth router should be defined.""" + assert hasattr(testbed, "auth") + + def test_auth_router_has_routes(self): + """The auth router should have registered routes.""" + routes = [route.path for route in testbed.auth.routes] + + assert "/testsets" 
in routes + assert "/evaluations" in routes + assert "/evaluation" in routes + assert "/testset_qa" in routes + assert "/testset_delete/{tid}" in routes + assert "/testset_load" in routes + assert "/testset_generate" in routes + assert "/evaluate" in routes + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured.""" + assert hasattr(testbed, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert testbed.logger.name == "endpoints.v1.testbed" From 0cf7353e30dcee14347b6606602e7b5e55696778 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 14:26:37 +0000 Subject: [PATCH 07/20] Updated Tests --- src/server/api/v1/databases.py | 3 +- test/conftest.py | 174 ++-------- test/db_fixtures.py | 209 ++++++++++++ test/integration/__init__.py | 6 + test/integration/server/__init__.py | 6 + test/integration/server/api/__init__.py | 6 + test/integration/server/api/conftest.py | 195 +++++++++++ test/integration/server/api/utils/__init__.py | 6 + test/integration/server/api/v1/__init__.py | 6 + .../server/api/v1/test_databases.py | 153 +++++++++ test/integration/server/api/v1/test_models.py | 271 ++++++++++++++++ test/integration/server/api/v1/test_oci.py | 224 +++++++++++++ test/integration/server/api/v1/test_probes.py | 74 +++++ .../server/api/v1/test_settings.py | 307 ++++++++++++++++++ .../server/api/utils/test_utils_databases.py | 155 +++++++++ test/unit/server/api/v1/test_v1_databases.py | 108 +++++- .../api/utils/test_utils_databases_crud.py | 118 +------ .../utils/test_utils_databases_functions.py | 224 +------------ .../unit/api/utils/test_utils_models.py | 200 ++---------- tests/server/unit/api/utils/test_utils_oci.py | 197 +---------- .../unit/api/utils/test_utils_settings.py | 225 ++----------- 21 files changed, 1833 insertions(+), 1034 deletions(-) create mode 100644 test/db_fixtures.py create mode 100644 test/integration/__init__.py create mode 100644 test/integration/server/__init__.py create mode 100644 test/integration/server/api/__init__.py create mode 100644 test/integration/server/api/conftest.py create mode 100644 test/integration/server/api/utils/__init__.py create mode 100644 test/integration/server/api/v1/__init__.py create mode 100644 test/integration/server/api/v1/test_databases.py create mode 100644 test/integration/server/api/v1/test_models.py create mode 100644 test/integration/server/api/v1/test_oci.py create mode 100644 test/integration/server/api/v1/test_probes.py create mode 100644 test/integration/server/api/v1/test_settings.py diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index d17ffcf4..acf33168 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -39,6 +39,7 @@ async def databases_list() -> list[schema.Database]: return database_objects + @auth.get( "/{name}", description="Get single database configuration and vector storage", @@ -101,7 +102,7 @@ async def databases_update( database_objects = utils_databases.get_databases() for other_db in database_objects: if other_db.name != name and other_db.connection: - other_db.set_connection(utils_databases.disconnect(db.connection)) + other_db.set_connection(utils_databases.disconnect(other_db.connection)) other_db.connected = False return db diff --git a/test/conftest.py b/test/conftest.py index e409c62e..c25db9f6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,159 +3,25 @@ Licensed under the Universal Permissive License 
v1.0 as shown at http://oss.oracle.com/licenses/upl. Pytest fixtures for unit tests with real Oracle database. -Adapts the Docker container pattern from tests/conftest.py. -""" - -# pylint: disable=consider-using-with -# pylint: disable=redefined-outer-name -# Pytest fixtures use parameter injection where fixture names match parameters - -import time -import shutil -from pathlib import Path -from typing import Generator, Optional -from contextlib import contextmanager - -import pytest -import oracledb -import docker -from docker.errors import DockerException -from docker.models.containers import Container - - -# Test database configuration - matches tests/conftest.py -TEST_CONFIG = { - "db_username": "PYTEST", - "db_password": "OrA_41_3xPl0d3r", - "db_dsn": "//localhost:1525/FREEPDB1", -} - - -def wait_for_container_ready(container: Container, ready_output: str, since: Optional[int] = None) -> None: - """Wait for container to be ready by checking its logs with exponential backoff.""" - start_time = time.time() - retry_interval = 2 - - while time.time() - start_time < 120: # 2 minute timeout - try: - logs = container.logs(tail=100, since=since).decode("utf-8") - if ready_output in logs: - return - except DockerException as e: - container.remove(force=True) - raise DockerException(f"Failed to get container logs: {str(e)}") from e - - time.sleep(retry_interval) - retry_interval = min(retry_interval * 2, 10) # Exponential backoff, max 10 seconds - - container.remove(force=True) - raise TimeoutError("Container did not become ready within timeout") - - -@contextmanager -def temp_sql_setup(): - """Context manager for temporary SQL setup files.""" - temp_dir = Path("test/db_startup_temp") - try: - temp_dir.mkdir(exist_ok=True) - sql_content = f""" - alter system set vector_memory_size=512M scope=spfile; - - alter session set container=FREEPDB1; - CREATE TABLESPACE IF NOT EXISTS USERS DATAFILE '/opt/oracle/oradata/FREE/FREEPDB1/users_01.dbf' SIZE 100M; - CREATE USER IF NOT EXISTS "{TEST_CONFIG["db_username"]}" IDENTIFIED BY {TEST_CONFIG["db_password"]} - DEFAULT TABLESPACE "USERS" - TEMPORARY TABLESPACE "TEMP"; - GRANT "DB_DEVELOPER_ROLE" TO "{TEST_CONFIG["db_username"]}"; - ALTER USER "{TEST_CONFIG["db_username"]}" DEFAULT ROLE ALL; - ALTER USER "{TEST_CONFIG["db_username"]}" QUOTA UNLIMITED ON USERS; - - EXIT; - """ - - temp_sql_file = temp_dir / "01_db_user.sql" - temp_sql_file.write_text(sql_content, encoding="UTF-8") - yield temp_dir - finally: - if temp_dir.exists(): - shutil.rmtree(temp_dir) +Re-exports shared database fixtures from test.db_fixtures. 
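+
+A minimal usage sketch (the test name is hypothetical; pytest injects the
+fixtures by name):
+
+    def test_select_one(db_transaction):
+        with db_transaction.cursor() as cursor:
+            cursor.execute("SELECT 1 FROM dual")
+            assert cursor.fetchone() == (1,)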
+""" -@pytest.fixture(scope="session") -def db_container() -> Generator[Container, None, None]: - """Create and manage an Oracle database container for testing.""" - db_client = docker.from_env() - container = None - - try: - with temp_sql_setup() as temp_dir: - container = db_client.containers.run( - "container-registry.oracle.com/database/free:latest-lite", - environment={ - "ORACLE_PWD": TEST_CONFIG["db_password"], - "ORACLE_PDB": TEST_CONFIG["db_dsn"].rsplit("/", maxsplit=1)[-1], # FREEPDB1 - }, - ports={"1521/tcp": int(TEST_CONFIG["db_dsn"].split(":")[1].split("/")[0])}, # 1525 - volumes={str(temp_dir.absolute()): {"bind": "/opt/oracle/scripts/startup", "mode": "ro"}}, - detach=True, - ) - - # Wait for database to be ready - wait_for_container_ready(container, "DATABASE IS READY TO USE!") - - # Restart container to apply vector_memory_size - container.restart() - restart_time = int(time.time()) - wait_for_container_ready(container, "DATABASE IS READY TO USE!", since=restart_time) - - yield container - - except DockerException as e: - if container: - container.remove(force=True) - raise DockerException(f"Docker operation failed: {str(e)}") from e - - finally: - if container: - try: - container.stop(timeout=30) - container.remove() - except DockerException as e: - print(f"Warning: Failed to cleanup database container: {str(e)}") - - -@pytest.fixture(scope="session") -def db_connection(db_container) -> Generator[oracledb.Connection, None, None]: - """Session-scoped real Oracle database connection. - - Depends on db_container to ensure database is running. - Fails explicitly if connection cannot be established. - """ - # pylint: disable=unused-argument - conn = oracledb.connect( - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - yield conn - conn.close() - - -@pytest.fixture -def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: - """Transaction isolation for each test using savepoints. - - Creates a savepoint before each test and rolls back after, - ensuring tests don't affect each other's database state. - - Note: This is NOT autouse - tests must explicitly request it - to get transaction isolation. This allows tests that don't - need database access to run without the overhead. - """ - cursor = db_connection.cursor() - cursor.execute("SAVEPOINT test_savepoint") - - yield db_connection - - cursor.execute("ROLLBACK TO SAVEPOINT test_savepoint") - cursor.close() +# Re-export shared fixtures for pytest discovery +from test.db_fixtures import ( + TEST_DB_CONFIG, + db_container, + db_connection, + db_transaction, +) + +# Expose TEST_CONFIG alias for backwards compatibility +TEST_CONFIG = TEST_DB_CONFIG + +__all__ = [ + "TEST_CONFIG", + "TEST_DB_CONFIG", + "db_container", + "db_connection", + "db_transaction", +] diff --git a/test/db_fixtures.py b/test/db_fixtures.py new file mode 100644 index 00000000..51609173 --- /dev/null +++ b/test/db_fixtures.py @@ -0,0 +1,209 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Shared database fixtures and utilities for tests. + +This module provides common database container management functions +used by both unit and integration tests. 
+""" + +# pylint: disable=redefined-outer-name +# Pytest fixtures use parameter injection where fixture names match parameters + +import time +import shutil +from pathlib import Path +from typing import Generator, Optional +from contextlib import contextmanager + +import pytest +import oracledb +import docker +from docker.errors import DockerException +from docker.models.containers import Container + + +# Test database configuration - shared across all tests +TEST_DB_CONFIG = { + "db_username": "PYTEST", + "db_password": "OrA_41_3xPl0d3r", + "db_dsn": "//localhost:1525/FREEPDB1", +} + + +def wait_for_container_ready( + container: Container, + ready_output: str, + since: Optional[int] = None, + timeout: int = 120, +) -> None: + """Wait for container to be ready by checking its logs with exponential backoff. + + Args: + container: Docker container to monitor + ready_output: String to look for in logs indicating readiness + since: Unix timestamp to filter logs from (optional) + timeout: Maximum seconds to wait (default 120) + + Raises: + TimeoutError: If container doesn't become ready within timeout + DockerException: If there's an error getting container logs + """ + start_time = time.time() + retry_interval = 2 + + while time.time() - start_time < timeout: + try: + logs = container.logs(tail=100, since=since).decode("utf-8") + if ready_output in logs: + return + except DockerException as e: + container.remove(force=True) + raise DockerException(f"Failed to get container logs: {str(e)}") from e + + time.sleep(retry_interval) + retry_interval = min(retry_interval * 2, 10) + + container.remove(force=True) + raise TimeoutError("Container did not become ready within timeout") + + +@contextmanager +def temp_sql_setup(temp_dir_path: str = "test/db_startup_temp"): + """Context manager for temporary SQL setup files. + + Creates a temporary directory with SQL initialization scripts + for the Oracle container. + + Args: + temp_dir_path: Path for temporary directory + + Yields: + Path object to the temporary directory + """ + temp_dir = Path(temp_dir_path) + try: + temp_dir.mkdir(exist_ok=True) + sql_content = f""" + alter system set vector_memory_size=512M scope=spfile; + + alter session set container=FREEPDB1; + CREATE TABLESPACE IF NOT EXISTS USERS DATAFILE '/opt/oracle/oradata/FREE/FREEPDB1/users_01.dbf' SIZE 100M; + CREATE USER IF NOT EXISTS "{TEST_DB_CONFIG["db_username"]}" IDENTIFIED BY {TEST_DB_CONFIG["db_password"]} + DEFAULT TABLESPACE "USERS" + TEMPORARY TABLESPACE "TEMP"; + GRANT "DB_DEVELOPER_ROLE" TO "{TEST_DB_CONFIG["db_username"]}"; + ALTER USER "{TEST_DB_CONFIG["db_username"]}" DEFAULT ROLE ALL; + ALTER USER "{TEST_DB_CONFIG["db_username"]}" QUOTA UNLIMITED ON USERS; + + EXIT; + """ + + temp_sql_file = temp_dir / "01_db_user.sql" + temp_sql_file.write_text(sql_content, encoding="UTF-8") + yield temp_dir + finally: + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + +def create_db_container(temp_dir_name: str = "test/db_startup_temp") -> Generator[Container, None, None]: + """Create and manage an Oracle database container for testing. + + This generator function handles the full lifecycle of a Docker-based + Oracle database container for testing purposes. 
+ + Args: + temp_dir_name: Path for temporary SQL setup files + + Yields: + Docker Container object for the running database + + Raises: + DockerException: If Docker operations fail + """ + db_client = docker.from_env() + container = None + + try: + with temp_sql_setup(temp_dir_name) as temp_dir: + container = db_client.containers.run( + "container-registry.oracle.com/database/free:latest-lite", + environment={ + "ORACLE_PWD": TEST_DB_CONFIG["db_password"], + "ORACLE_PDB": TEST_DB_CONFIG["db_dsn"].rsplit("/", maxsplit=1)[-1], + }, + ports={"1521/tcp": int(TEST_DB_CONFIG["db_dsn"].split(":")[1].split("/")[0])}, + volumes={str(temp_dir.absolute()): {"bind": "/opt/oracle/scripts/startup", "mode": "ro"}}, + detach=True, + ) + + # Wait for database to be ready + wait_for_container_ready(container, "DATABASE IS READY TO USE!") + + # Restart container to apply vector_memory_size + container.restart() + restart_time = int(time.time()) + wait_for_container_ready(container, "DATABASE IS READY TO USE!", since=restart_time) + + yield container + + except DockerException as e: + if container: + container.remove(force=True) + raise DockerException(f"Docker operation failed: {str(e)}") from e + + finally: + if container: + try: + container.stop(timeout=30) + container.remove() + except DockerException as e: + print(f"Warning: Failed to cleanup database container: {str(e)}") + + +@pytest.fixture(scope="session") +def db_container() -> Generator[Container, None, None]: + """Pytest fixture for Oracle database container. + + Session-scoped fixture that creates and manages an Oracle database + container for the duration of the test session. + """ + yield from create_db_container() + + +@pytest.fixture(scope="session") +def db_connection(db_container) -> Generator[oracledb.Connection, None, None]: + """Session-scoped real Oracle database connection. + + Depends on db_container to ensure database is running. + """ + _ = db_container # Ensure container is running + conn = oracledb.connect( + user=TEST_DB_CONFIG["db_username"], + password=TEST_DB_CONFIG["db_password"], + dsn=TEST_DB_CONFIG["db_dsn"], + ) + yield conn + conn.close() + + +@pytest.fixture +def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: + """Transaction isolation for each test using savepoints. + + Creates a savepoint before each test and rolls back after, + ensuring tests don't affect each other's database state. + + Note: DDL operations (CREATE TABLE, etc.) cause implicit commits + in Oracle, which will invalidate the savepoint. Tests with DDL + should use mocks or handle cleanup manually. + """ + cursor = db_connection.cursor() + cursor.execute("SAVEPOINT test_savepoint") + + yield db_connection + + cursor.execute("ROLLBACK TO SAVEPOINT test_savepoint") + cursor.close() diff --git a/test/integration/__init__.py b/test/integration/__init__.py new file mode 100644 index 00000000..2577126f --- /dev/null +++ b/test/integration/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests package. +""" diff --git a/test/integration/server/__init__.py b/test/integration/server/__init__.py new file mode 100644 index 00000000..a242d937 --- /dev/null +++ b/test/integration/server/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+ +Server integration tests package. +""" diff --git a/test/integration/server/api/__init__.py b/test/integration/server/api/__init__.py new file mode 100644 index 00000000..c4b92db3 --- /dev/null +++ b/test/integration/server/api/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Server API integration tests package. +""" diff --git a/test/integration/server/api/conftest.py b/test/integration/server/api/conftest.py new file mode 100644 index 00000000..045d68e0 --- /dev/null +++ b/test/integration/server/api/conftest.py @@ -0,0 +1,195 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Pytest fixtures for server API integration tests. + +Integration tests use a real FastAPI TestClient with the actual application, +testing the full request/response cycle through the API layer. + +Note: db_container fixture is inherited from test/conftest.py - do not import here. +""" + +# pylint: disable=redefined-outer-name +# Pytest fixtures use parameter injection where fixture names match parameters + +import os +import asyncio +from typing import Generator + +from test.db_fixtures import TEST_DB_CONFIG + +import pytest +from fastapi.testclient import TestClient + +from common.schema import Database, Model +from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS + + +# Clear environment variables that could interfere with tests +# This must happen before importing application modules +API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] +DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] +MODEL_VARS = ["ON_PREM_OLLAMA_URL", "ON_PREM_HF_URL", "OPENAI_API_KEY", "PPLX_API_KEY", "COHERE_API_KEY"] +for env_var in [*API_VARS, *DB_VARS, *MODEL_VARS, *[var for var in os.environ if var.startswith("OCI_")]]: + os.environ.pop(env_var, None) + +# Test configuration - extends shared DB config with integration-specific settings +TEST_CONFIG = { + "client": "integration_test", + "auth_token": "integration-test-token", + **TEST_DB_CONFIG, +} + +# Set environment variables for test server +os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" # Use empty config +os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" # Prevent OCI config pickup +os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] + + +################################################# +# Authentication Headers +################################################# +@pytest.fixture +def auth_headers(): + """Return common header configurations for testing.""" + return { + "no_auth": {}, + "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CONFIG["client"]}, + "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": TEST_CONFIG["client"]}, + } + + +################################################# +# FastAPI Test Client +################################################# +@pytest.fixture(scope="session") +def app(): + """Create the FastAPI application for testing. + + This fixture creates the actual FastAPI app using the same factory + function as the production server (launch_server.create_app). + + Import is done inside the fixture to ensure environment variables + are set before any application modules are loaded. 
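+
+    Note: server.bootstrap.bootstrap is imported at module scope above for
+    the state-manager fixtures, so only modules first imported through
+    launch_server are guaranteed to see the cleared/overridden environment.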
+ """ + # pylint: disable=import-outside-toplevel + from launch_server import create_app + + return asyncio.run(create_app()) + + +@pytest.fixture(scope="session") +def client(app) -> Generator[TestClient, None, None]: + """Create a TestClient for the FastAPI app. + + The TestClient allows making HTTP requests to the app without + starting a real server, enabling fast integration testing. + """ + with TestClient(app) as test_client: + yield test_client + + +################################################# +# Test Data Helpers +################################################# +@pytest.fixture +def test_db_payload(): + """Get standard test database payload for integration tests.""" + return { + "user": TEST_CONFIG["db_username"], + "password": TEST_CONFIG["db_password"], + "dsn": TEST_CONFIG["db_dsn"], + } + + +@pytest.fixture +def sample_model_payload(): + """Sample model configuration for testing.""" + return { + "id": "test-model", + "type": "ll", + "provider": "openai", + "enabled": True, + } + + +@pytest.fixture +def sample_settings_payload(): + """Sample settings configuration for testing.""" + return { + "client": TEST_CONFIG["client"], + "ll_model": { + "model": "gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 4096, + "chat_history": True, + }, + } + + +################################################# +# Schema Factory Fixtures +################################################# +@pytest.fixture +def make_database(): + """Factory fixture for creating Database objects.""" + def _make_database(**kwargs): + defaults = { + "name": "TEST_DB", + "user": "test_user", + "password": "test_password", + "dsn": "localhost:1521/TEST", + } + defaults.update(kwargs) + return Database(**defaults) + return _make_database + + +@pytest.fixture +def make_model(): + """Factory fixture for creating Model objects.""" + def _make_model(**kwargs): + defaults = { + "id": "test-model", + "type": "ll", + "provider": "openai", + "enabled": True, + } + defaults.update(kwargs) + return Model(**defaults) + return _make_model + + +################################################# +# State Management Helpers +################################################# +@pytest.fixture +def db_objects_manager(): + """Fixture to manage DATABASE_OBJECTS save/restore operations. + + This fixture saves the current state of DATABASE_OBJECTS before each test + and restores it afterward, ensuring tests don't affect each other. + """ + original_db_objects = DATABASE_OBJECTS.copy() + yield DATABASE_OBJECTS + DATABASE_OBJECTS.clear() + DATABASE_OBJECTS.extend(original_db_objects) + + +@pytest.fixture +def model_objects_manager(): + """Fixture to manage MODEL_OBJECTS save/restore operations.""" + original_model_objects = MODEL_OBJECTS.copy() + yield MODEL_OBJECTS + MODEL_OBJECTS.clear() + MODEL_OBJECTS.extend(original_model_objects) + + +@pytest.fixture +def settings_objects_manager(): + """Fixture to manage SETTINGS_OBJECTS save/restore operations.""" + original_settings_objects = SETTINGS_OBJECTS.copy() + yield SETTINGS_OBJECTS + SETTINGS_OBJECTS.clear() + SETTINGS_OBJECTS.extend(original_settings_objects) diff --git a/test/integration/server/api/utils/__init__.py b/test/integration/server/api/utils/__init__.py new file mode 100644 index 00000000..37340b95 --- /dev/null +++ b/test/integration/server/api/utils/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+ +Server API utils integration tests package. +""" diff --git a/test/integration/server/api/v1/__init__.py b/test/integration/server/api/v1/__init__.py new file mode 100644 index 00000000..d55308b1 --- /dev/null +++ b/test/integration/server/api/v1/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Server API v1 integration tests package. +""" diff --git a/test/integration/server/api/v1/test_databases.py b/test/integration/server/api/v1/test_databases.py new file mode 100644 index 00000000..847ef735 --- /dev/null +++ b/test/integration/server/api/v1/test_databases.py @@ -0,0 +1,153 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/databases.py + +Tests the database configuration endpoints through the full API stack. +These endpoints require authentication. +""" + + +class TestAuthentication: + """Integration tests for authentication on database endpoints.""" + + def test_databases_list_requires_auth(self, client): + """GET /v1/databases should require authentication.""" + response = client.get("/v1/databases") + + assert response.status_code == 401 # No auth header = Unauthorized + + def test_databases_list_rejects_invalid_token(self, client, auth_headers): + """GET /v1/databases should reject invalid tokens.""" + response = client.get("/v1/databases", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + def test_databases_list_accepts_valid_token(self, client, auth_headers): + """GET /v1/databases should accept valid tokens.""" + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + + +class TestDatabasesList: + """Integration tests for the databases list endpoint.""" + + def test_databases_list_returns_list(self, client, auth_headers): + """GET /v1/databases should return a list of databases.""" + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_databases_list_contains_default(self, client, auth_headers): + """GET /v1/databases should contain a DEFAULT database.""" + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + + data = response.json() + # There should be at least one database (DEFAULT is created by bootstrap) + # If no config file, the list may be empty or contain DEFAULT + assert isinstance(data, list) + + def test_databases_list_returns_database_schema(self, client, auth_headers, db_objects_manager, make_database): + """GET /v1/databases should return databases with correct schema.""" + # Ensure there's at least one database for testing + if not db_objects_manager: + db_objects_manager.append(make_database()) + + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + data = response.json() + if data: + db = data[0] + assert "name" in db + assert "user" in db + assert "dsn" in db + assert "connected" in db + + +class TestDatabasesGet: + """Integration tests for the single database get endpoint.""" + + def test_databases_get_requires_auth(self, client): + """GET /v1/databases/{name} should require authentication.""" + response = client.get("/v1/databases/DEFAULT") + + assert 
response.status_code == 401 + + def test_databases_get_returns_404_for_unknown(self, client, auth_headers): + """GET /v1/databases/{name} should return 404 for unknown database.""" + response = client.get("/v1/databases/NONEXISTENT_DB", headers=auth_headers["valid_auth"]) + + assert response.status_code == 404 + + def test_databases_get_returns_database(self, client, auth_headers, db_objects_manager, make_database): + """GET /v1/databases/{name} should return the specified database.""" + # Ensure there's a test database + test_db = make_database(name="INTEGRATION_TEST_DB") + db_objects_manager.append(test_db) + + response = client.get("/v1/databases/INTEGRATION_TEST_DB", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "INTEGRATION_TEST_DB" + + +class TestDatabasesUpdate: + """Integration tests for the database update endpoint.""" + + def test_databases_update_requires_auth(self, client): + """PATCH /v1/databases/{name} should require authentication.""" + response = client.patch("/v1/databases/DEFAULT", json={"user": "test"}) + + assert response.status_code == 401 + + def test_databases_update_returns_404_for_unknown(self, client, auth_headers): + """PATCH /v1/databases/{name} should return 404 for unknown database.""" + response = client.patch( + "/v1/databases/NONEXISTENT_DB", + headers=auth_headers["valid_auth"], + json={"user": "test", "password": "test", "dsn": "localhost:1521/TEST"}, + ) + + assert response.status_code == 404 + + def test_databases_update_validates_connection(self, client, auth_headers, db_objects_manager, make_database): + """PATCH /v1/databases/{name} should validate connection details.""" + # Add a test database + test_db = make_database(name="UPDATE_TEST_DB") + db_objects_manager.append(test_db) + + # Try to update with invalid connection details (no real DB running) + response = client.patch( + "/v1/databases/UPDATE_TEST_DB", + headers=auth_headers["valid_auth"], + json={"user": "invalid", "password": "invalid", "dsn": "localhost:9999/INVALID"}, + ) + + # Should fail because it tries to connect + assert response.status_code in [400, 401, 404, 503] + + def test_databases_update_connects_to_real_db( + self, client, auth_headers, db_objects_manager, db_container, test_db_payload, make_database + ): + """PATCH /v1/databases/{name} should connect to real database.""" + _ = db_container # Ensure container is running + # Add a test database + test_db = make_database(name="REAL_DB_TEST", user="placeholder", password="placeholder", dsn="placeholder") + db_objects_manager.append(test_db) + + response = client.patch( + "/v1/databases/REAL_DB_TEST", + headers=auth_headers["valid_auth"], + json=test_db_payload, + ) + + assert response.status_code == 200 + data = response.json() + assert data["connected"] is True + assert data["user"] == test_db_payload["user"] diff --git a/test/integration/server/api/v1/test_models.py b/test/integration/server/api/v1/test_models.py new file mode 100644 index 00000000..74cbd11b --- /dev/null +++ b/test/integration/server/api/v1/test_models.py @@ -0,0 +1,271 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/models.py + +Tests the model configuration endpoints through the full API stack. +These endpoints require authentication. 
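+
+State-mutating tests use the model_objects_manager fixture from conftest.py,
+which restores MODEL_OBJECTS after each test, e.g. (test body is illustrative):
+
+    def test_example(client, auth_headers, model_objects_manager, make_model):
+        model_objects_manager.append(make_model(id="scratch-model"))
+        ...  # exercise the API; bootstrap state is restored afterwards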
+""" + + +class TestAuthentication: + """Integration tests for authentication on model endpoints.""" + + def test_models_list_requires_auth(self, client): + """GET /v1/models should require authentication.""" + response = client.get("/v1/models") + + assert response.status_code == 401 + + def test_models_list_rejects_invalid_token(self, client, auth_headers): + """GET /v1/models should reject invalid tokens.""" + response = client.get("/v1/models", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + def test_models_list_accepts_valid_token(self, client, auth_headers): + """GET /v1/models should accept valid tokens.""" + response = client.get("/v1/models", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + + +class TestModelsList: + """Integration tests for the models list endpoint.""" + + def test_models_list_returns_list(self, client, auth_headers): + """GET /v1/models should return a list of models.""" + response = client.get("/v1/models", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_models_list_returns_enabled_only_by_default(self, client, auth_headers): + """GET /v1/models should return only enabled models by default.""" + response = client.get("/v1/models", headers=auth_headers["valid_auth"]) + + data = response.json() + for model in data: + assert model["enabled"] is True + + def test_models_list_with_include_disabled(self, client, auth_headers): + """GET /v1/models?include_disabled=true should include disabled models.""" + response = client.get( + "/v1/models", + headers=auth_headers["valid_auth"], + params={"include_disabled": True}, + ) + + assert response.status_code == 200 + data = response.json() + # Should have at least some models (bootstrap loads defaults) + assert isinstance(data, list) + + def test_models_list_filter_by_type_ll(self, client, auth_headers): + """GET /v1/models?model_type=ll should return only LL models.""" + response = client.get( + "/v1/models", + headers=auth_headers["valid_auth"], + params={"model_type": "ll", "include_disabled": True}, + ) + + assert response.status_code == 200 + data = response.json() + for model in data: + assert model["type"] == "ll" + + def test_models_list_filter_by_type_embed(self, client, auth_headers): + """GET /v1/models?model_type=embed should return only embed models.""" + response = client.get( + "/v1/models", + headers=auth_headers["valid_auth"], + params={"model_type": "embed", "include_disabled": True}, + ) + + assert response.status_code == 200 + data = response.json() + for model in data: + assert model["type"] == "embed" + + +class TestModelsSupported: + """Integration tests for the supported models endpoint.""" + + def test_models_supported_returns_list(self, client, auth_headers): + """GET /v1/models/supported should return supported providers.""" + response = client.get("/v1/models/supported", headers=auth_headers["valid_auth"]) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_models_supported_filter_by_provider(self, client, auth_headers): + """GET /v1/models/supported?model_provider=openai should filter by provider.""" + response = client.get( + "/v1/models/supported", + headers=auth_headers["valid_auth"], + params={"model_provider": "openai"}, + ) + + assert response.status_code == 200 + data = response.json() + for item in data: + assert item.get("provider") == "openai" + + def 
test_models_supported_filter_by_type(self, client, auth_headers): + """GET /v1/models/supported?model_type=ll should filter by type.""" + response = client.get( + "/v1/models/supported", + headers=auth_headers["valid_auth"], + params={"model_type": "ll"}, + ) + + assert response.status_code == 200 + data = response.json() + # Response is a list of provider objects with provider and models keys + assert isinstance(data, list) + # Each item should have provider and models keys + for item in data: + assert "provider" in item + assert "models" in item + + +class TestModelsGet: + """Integration tests for the single model get endpoint.""" + + def test_models_get_requires_auth(self, client): + """GET /v1/models/{provider}/{id} should require authentication.""" + response = client.get("/v1/models/openai/gpt-4o-mini") + + assert response.status_code == 401 + + def test_models_get_returns_404_for_unknown(self, client, auth_headers): + """GET /v1/models/{provider}/{id} should return 404 for unknown model.""" + response = client.get( + "/v1/models/nonexistent/nonexistent-model", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + def test_models_get_returns_model(self, client, auth_headers, model_objects_manager, make_model): + """GET /v1/models/{provider}/{id} should return the specified model.""" + # Add a test model + test_model = make_model(id="integration-test-model") + model_objects_manager.append(test_model) + + response = client.get( + "/v1/models/openai/integration-test-model", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == "integration-test-model" + assert data["provider"] == "openai" + + +class TestModelsCreate: + """Integration tests for the model create endpoint.""" + + def test_models_create_requires_auth(self, client): + """POST /v1/models should require authentication.""" + response = client.post( + "/v1/models", + json={"id": "test-model", "type": "ll", "provider": "openai", "enabled": True}, + ) + + assert response.status_code == 401 + + def test_models_create_success(self, client, auth_headers, model_objects_manager): + """POST /v1/models should create a new model.""" + # pylint: disable=unused-argument + response = client.post( + "/v1/models", + headers=auth_headers["valid_auth"], + json={"id": "new-test-model", "type": "ll", "provider": "openai", "enabled": True}, + ) + + assert response.status_code == 201 + data = response.json() + assert data["id"] == "new-test-model" + assert data["provider"] == "openai" + + def test_models_create_returns_409_for_duplicate(self, client, auth_headers, model_objects_manager, make_model): + """POST /v1/models should return 409 for duplicate model.""" + # Add existing model + existing_model = make_model(id="duplicate-model") + model_objects_manager.append(existing_model) + + response = client.post( + "/v1/models", + headers=auth_headers["valid_auth"], + json={"id": "duplicate-model", "type": "ll", "provider": "openai", "enabled": True}, + ) + + assert response.status_code == 409 + + +class TestModelsUpdate: + """Integration tests for the model update endpoint.""" + + def test_models_update_requires_auth(self, client): + """PATCH /v1/models/{provider}/{id} should require authentication.""" + response = client.patch( + "/v1/models/openai/test-model", + json={"id": "test-model", "type": "ll", "provider": "openai", "enabled": False}, + ) + + assert response.status_code == 401 + + def test_models_update_returns_404_for_unknown(self, client, 
auth_headers): + """PATCH /v1/models/{provider}/{id} should return 404 for unknown model.""" + response = client.patch( + "/v1/models/nonexistent/nonexistent-model", + headers=auth_headers["valid_auth"], + json={"id": "nonexistent-model", "type": "ll", "provider": "nonexistent", "enabled": False}, + ) + + assert response.status_code == 404 + + def test_models_update_success(self, client, auth_headers, model_objects_manager, make_model): + """PATCH /v1/models/{provider}/{id} should update the model.""" + # Add a test model + test_model = make_model(id="update-test-model") + model_objects_manager.append(test_model) + + response = client.patch( + "/v1/models/openai/update-test-model", + headers=auth_headers["valid_auth"], + json={"id": "update-test-model", "type": "ll", "provider": "openai", "enabled": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["enabled"] is False + + +class TestModelsDelete: + """Integration tests for the model delete endpoint.""" + + def test_models_delete_requires_auth(self, client): + """DELETE /v1/models/{provider}/{id} should require authentication.""" + response = client.delete("/v1/models/openai/test-model") + + assert response.status_code == 401 + + def test_models_delete_success(self, client, auth_headers, model_objects_manager, make_model): + """DELETE /v1/models/{provider}/{id} should delete the model.""" + # Add a test model to delete + test_model = make_model(id="delete-test-model") + model_objects_manager.append(test_model) + + response = client.delete( + "/v1/models/openai/delete-test-model", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 200 + assert "deleted" in response.json()["message"].lower() diff --git a/test/integration/server/api/v1/test_oci.py b/test/integration/server/api/v1/test_oci.py new file mode 100644 index 00000000..aeed656b --- /dev/null +++ b/test/integration/server/api/v1/test_oci.py @@ -0,0 +1,224 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/oci.py + +Tests the OCI configuration endpoints through the full API stack. +These endpoints require authentication. + +Note: Most OCI operations require valid OCI credentials. Tests without +real OCI credentials will verify endpoint availability and authentication. 
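+
+The API conftest pops all OCI_* environment variables and points
+OCI_CLI_CONFIG_FILE at a non-existent path, so no local OCI profile is
+picked up during these tests.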
+""" + + +class TestOciList: + """Integration tests for the OCI list endpoint.""" + + def test_oci_list_requires_auth(self, client): + """GET /v1/oci should require authentication.""" + response = client.get("/v1/oci") + + assert response.status_code == 401 + + def test_oci_list_rejects_invalid_token(self, client, auth_headers): + """GET /v1/oci should reject invalid tokens.""" + response = client.get("/v1/oci", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + def test_oci_list_accepts_valid_token(self, client, auth_headers): + """GET /v1/oci should accept valid tokens.""" + response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) + + # May return 200 (with configs) or 404 (no configs) + assert response.status_code in [200, 404] + + def test_oci_list_returns_list_or_404(self, client, auth_headers): + """GET /v1/oci should return a list of OCI configs or 404 if none.""" + response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) + + if response.status_code == 200: + data = response.json() + assert isinstance(data, list) + else: + assert response.status_code == 404 + + +class TestOciGet: + """Integration tests for the single OCI profile get endpoint.""" + + def test_oci_get_requires_auth(self, client): + """GET /v1/oci/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/DEFAULT") + + assert response.status_code == 401 + + def test_oci_get_returns_404_for_unknown(self, client, auth_headers): + """GET /v1/oci/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + +class TestOciRegions: + """Integration tests for the OCI regions endpoint.""" + + def test_oci_regions_requires_auth(self, client): + """GET /v1/oci/regions/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/regions/DEFAULT") + + assert response.status_code == 401 + + def test_oci_regions_returns_404_for_unknown_profile(self, client, auth_headers): + """GET /v1/oci/regions/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/regions/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + +class TestOciGenai: + """Integration tests for the OCI GenAI models endpoint.""" + + def test_oci_genai_requires_auth(self, client): + """GET /v1/oci/genai/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/genai/DEFAULT") + + assert response.status_code == 401 + + def test_oci_genai_returns_404_for_unknown_profile(self, client, auth_headers): + """GET /v1/oci/genai/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/genai/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + +class TestOciCompartments: + """Integration tests for the OCI compartments endpoint.""" + + def test_oci_compartments_requires_auth(self, client): + """GET /v1/oci/compartments/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/compartments/DEFAULT") + + assert response.status_code == 401 + + def test_oci_compartments_returns_404_for_unknown_profile(self, client, auth_headers): + """GET /v1/oci/compartments/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/compartments/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + 
assert response.status_code == 404 + + +class TestOciBuckets: + """Integration tests for the OCI buckets endpoint.""" + + def test_oci_buckets_requires_auth(self, client): + """GET /v1/oci/buckets/{compartment_ocid}/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/buckets/ocid1.compartment.oc1..test/DEFAULT") + + assert response.status_code == 401 + + def test_oci_buckets_returns_404_for_unknown_profile(self, client, auth_headers): + """GET /v1/oci/buckets/{compartment_ocid}/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/buckets/ocid1.compartment.oc1..test/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + +class TestOciObjects: + """Integration tests for the OCI bucket objects endpoint.""" + + def test_oci_objects_requires_auth(self, client): + """GET /v1/oci/objects/{bucket_name}/{auth_profile} should require authentication.""" + response = client.get("/v1/oci/objects/test-bucket/DEFAULT") + + assert response.status_code == 401 + + def test_oci_objects_returns_404_for_unknown_profile(self, client, auth_headers): + """GET /v1/oci/objects/{bucket_name}/{auth_profile} should return 404 for unknown profile.""" + response = client.get( + "/v1/oci/objects/test-bucket/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 404 + + +class TestOciUpdate: + """Integration tests for the OCI profile update endpoint.""" + + def test_oci_update_requires_auth(self, client): + """PATCH /v1/oci/{auth_profile} should require authentication.""" + response = client.patch( + "/v1/oci/DEFAULT", + json={"auth_profile": "DEFAULT", "genai_region": "us-ashburn-1"}, + ) + + assert response.status_code == 401 + + def test_oci_update_returns_404_for_unknown_profile(self, client, auth_headers): + """PATCH /v1/oci/{auth_profile} should return 404 for unknown profile.""" + response = client.patch( + "/v1/oci/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + json={"auth_profile": "NONEXISTENT_PROFILE", "genai_region": "us-ashburn-1"}, + ) + + assert response.status_code == 404 + + +class TestOciDownloadObjects: + """Integration tests for the OCI download objects endpoint.""" + + def test_oci_download_requires_auth(self, client): + """POST /v1/oci/objects/download/{bucket_name}/{auth_profile} should require authentication.""" + response = client.post( + "/v1/oci/objects/download/test-bucket/DEFAULT", + json=["file1.txt"], + ) + + assert response.status_code == 401 + + def test_oci_download_returns_404_for_unknown_profile(self, client, auth_headers): + """POST /v1/oci/objects/download/{bucket_name}/{auth_profile} should return 404 for unknown profile.""" + response = client.post( + "/v1/oci/objects/download/test-bucket/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + json=["file1.txt"], + ) + + assert response.status_code == 404 + + +class TestOciCreateGenaiModels: + """Integration tests for the OCI create GenAI models endpoint.""" + + def test_oci_create_genai_requires_auth(self, client): + """POST /v1/oci/genai/{auth_profile} should require authentication.""" + response = client.post("/v1/oci/genai/DEFAULT") + + assert response.status_code == 401 + + def test_oci_create_genai_returns_404_for_unknown_profile(self, client, auth_headers): + """POST /v1/oci/genai/{auth_profile} should return 404 for unknown profile.""" + response = client.post( + "/v1/oci/genai/NONEXISTENT_PROFILE", + headers=auth_headers["valid_auth"], + ) 
+ + assert response.status_code == 404 diff --git a/test/integration/server/api/v1/test_probes.py b/test/integration/server/api/v1/test_probes.py new file mode 100644 index 00000000..9d0401e8 --- /dev/null +++ b/test/integration/server/api/v1/test_probes.py @@ -0,0 +1,74 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/probes.py + +Tests the Kubernetes probe endpoints (liveness, readiness, MCP health). +These endpoints do not require authentication. +""" + + +class TestLivenessProbe: + """Integration tests for the liveness probe endpoint.""" + + def test_liveness_returns_200(self, client): + """GET /v1/liveness should return 200 with status alive.""" + response = client.get("/v1/liveness") + + assert response.status_code == 200 + assert response.json() == {"status": "alive"} + + def test_liveness_no_auth_required(self, client): + """GET /v1/liveness should not require authentication.""" + # No auth headers provided + response = client.get("/v1/liveness") + + assert response.status_code == 200 + + +class TestReadinessProbe: + """Integration tests for the readiness probe endpoint.""" + + def test_readiness_returns_200(self, client): + """GET /v1/readiness should return 200 with status ready.""" + response = client.get("/v1/readiness") + + assert response.status_code == 200 + assert response.json() == {"status": "ready"} + + def test_readiness_no_auth_required(self, client): + """GET /v1/readiness should not require authentication.""" + response = client.get("/v1/readiness") + + assert response.status_code == 200 + + +class TestMcpHealthz: + """Integration tests for the MCP health check endpoint.""" + + def test_mcp_healthz_returns_200(self, client): + """GET /v1/mcp/healthz should return 200 with MCP status.""" + response = client.get("/v1/mcp/healthz") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ready" + assert "name" in data + assert "version" in data + assert "available_tools" in data + + def test_mcp_healthz_no_auth_required(self, client): + """GET /v1/mcp/healthz should not require authentication.""" + response = client.get("/v1/mcp/healthz") + + assert response.status_code == 200 + + def test_mcp_healthz_returns_server_info(self, client): + """GET /v1/mcp/healthz should return MCP server information.""" + response = client.get("/v1/mcp/healthz") + + data = response.json() + assert data["name"] == "Oracle AI Optimizer and Toolkit MCP Server" + assert isinstance(data["available_tools"], int) + assert data["available_tools"] >= 0 diff --git a/test/integration/server/api/v1/test_settings.py b/test/integration/server/api/v1/test_settings.py new file mode 100644 index 00000000..15374061 --- /dev/null +++ b/test/integration/server/api/v1/test_settings.py @@ -0,0 +1,307 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/settings.py + +Tests the settings configuration endpoints through the full API stack. +These endpoints require authentication. 
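+
+Tests that create or modify clients rely on the settings_objects_manager
+fixture from conftest.py to restore SETTINGS_OBJECTS between tests.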
+""" + +import json +from io import BytesIO + + +class TestAuthentication: + """Integration tests for authentication on settings endpoints.""" + + def test_settings_get_requires_auth(self, client): + """GET /v1/settings should require authentication.""" + response = client.get("/v1/settings", params={"client": "test"}) + + assert response.status_code == 401 + + def test_settings_get_rejects_invalid_token(self, client, auth_headers): + """GET /v1/settings should reject invalid tokens.""" + response = client.get( + "/v1/settings", + headers=auth_headers["invalid_auth"], + params={"client": "test"}, + ) + + assert response.status_code == 401 + + def test_settings_get_accepts_valid_token(self, client, auth_headers): + """GET /v1/settings should accept valid tokens.""" + response = client.get( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "server"}, # Use existing client + ) + + assert response.status_code == 200 + + +class TestSettingsGet: + """Integration tests for the settings get endpoint.""" + + def test_settings_get_returns_settings(self, client, auth_headers): + """GET /v1/settings should return settings for existing client.""" + response = client.get( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "server"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "client" in data + assert data["client"] == "server" + + def test_settings_get_returns_404_for_unknown_client(self, client, auth_headers): + """GET /v1/settings should return 404 for unknown client.""" + response = client.get( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "nonexistent_client_xyz"}, + ) + + assert response.status_code == 404 + + def test_settings_get_full_config(self, client, auth_headers): + """GET /v1/settings?full_config=true should return full configuration.""" + response = client.get( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "server", "full_config": True}, + ) + + assert response.status_code == 200 + data = response.json() + # Full config includes client_settings and all config arrays + assert "client_settings" in data + assert "database_configs" in data + assert "model_configs" in data + assert "oci_configs" in data + assert "prompt_configs" in data + + def test_settings_get_with_sensitive(self, client, auth_headers): + """GET /v1/settings?incl_sensitive=true should include sensitive fields.""" + response = client.get( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "server", "full_config": True, "incl_sensitive": True}, + ) + + assert response.status_code == 200 + # Response should include sensitive fields (passwords) + # Exact fields depend on what's configured + + +class TestSettingsCreate: + """Integration tests for the settings create endpoint.""" + + def test_settings_create_requires_auth(self, client): + """POST /v1/settings should require authentication.""" + response = client.post("/v1/settings", params={"client": "new_test_client"}) + + assert response.status_code == 401 + + def test_settings_create_success(self, client, auth_headers, settings_objects_manager): + """POST /v1/settings should create new client settings.""" + # pylint: disable=unused-argument + response = client.post( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "integration_new_client"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["client"] == "integration_new_client" + + def 
test_settings_create_returns_409_for_existing(self, client, auth_headers): + """POST /v1/settings should return 409 if client already exists.""" + # "server" client is created by bootstrap + response = client.post( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "server"}, + ) + + assert response.status_code == 409 + + +class TestSettingsUpdate: + """Integration tests for the settings update endpoint.""" + + def test_settings_update_requires_auth(self, client): + """PATCH /v1/settings should require authentication.""" + response = client.patch( + "/v1/settings", + params={"client": "server"}, + json={"client": "server"}, + ) + + assert response.status_code == 401 + + def test_settings_update_returns_404_for_unknown(self, client, auth_headers): + """PATCH /v1/settings should return 404 for unknown client.""" + response = client.patch( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "nonexistent_client_xyz"}, + json={"client": "nonexistent_client_xyz"}, + ) + + assert response.status_code == 404 + + def test_settings_update_success(self, client, auth_headers, settings_objects_manager): + """PATCH /v1/settings should update client settings.""" + # pylint: disable=unused-argument + # First create a client to update + client.post( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "update_test_client"}, + ) + + # Now update it + response = client.patch( + "/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": "update_test_client"}, + json={ + "client": "update_test_client", + "ll_model": { + "model": "gpt-4o", + "temperature": 0.5, + "max_tokens": 2048, + "chat_history": False, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["ll_model"]["temperature"] == 0.5 + + +class TestSettingsLoadFromFile: + """Integration tests for the settings load from file endpoint.""" + + def test_load_from_file_requires_auth(self, client): + """POST /v1/settings/load/file should require authentication.""" + response = client.post( + "/v1/settings/load/file", + params={"client": "test"}, + files={"file": ("test.json", b"{}", "application/json")}, + ) + + assert response.status_code == 401 + + def test_load_from_file_rejects_non_json_extension(self, client, auth_headers, settings_objects_manager): + """POST /v1/settings/load/file should reject files without .json extension. + + Note: Current implementation returns 500 due to HTTPException being caught + by generic Exception handler. This documents actual behavior. 
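+
+        The failure mode is the familiar pattern of a broad handler
+        re-wrapping HTTPException (a sketch, not the endpoint's exact code):
+
+            try:
+                raise HTTPException(status_code=400, detail="Only JSON files are supported")
+            except Exception as ex:
+                raise HTTPException(status_code=500, detail=str(ex)) from ex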
+ """ + # pylint: disable=unused-argument + response = client.post( + "/v1/settings/load/file", + headers=auth_headers["valid_auth"], + params={"client": "file_test_client"}, + files={"file": ("test.txt", b"{}", "text/plain")}, + ) + + # Current behavior returns 500 (HTTPException caught by generic handler) + # Ideally should be 400, but documenting actual behavior + assert response.status_code == 500 + assert "Only JSON files are supported" in response.json()["detail"] + + def test_load_from_file_rejects_invalid_json_content(self, client, auth_headers, settings_objects_manager): + """POST /v1/settings/load/file should reject invalid JSON content.""" + # pylint: disable=unused-argument + response = client.post( + "/v1/settings/load/file", + headers=auth_headers["valid_auth"], + params={"client": "file_invalid_content"}, + files={"file": ("test.json", b"not valid json", "application/json")}, + ) + + # Invalid JSON content returns 400 + assert response.status_code == 400 + + def test_load_from_file_success(self, client, auth_headers, settings_objects_manager): + """POST /v1/settings/load/file should load configuration from JSON file.""" + # pylint: disable=unused-argument + config_data = { + "client_settings": { + "client": "file_load_client", + "ll_model": { + "model": "gpt-4o-mini", + "temperature": 0.8, + "max_tokens": 1000, + "chat_history": True, + }, + }, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + file_content = json.dumps(config_data).encode("utf-8") + + response = client.post( + "/v1/settings/load/file", + headers=auth_headers["valid_auth"], + params={"client": "file_load_client"}, + files={"file": ("config.json", BytesIO(file_content), "application/json")}, + ) + + assert response.status_code == 200 + assert "loaded successfully" in response.json()["message"].lower() + + +class TestSettingsLoadFromJson: + """Integration tests for the settings load from JSON endpoint.""" + + def test_load_from_json_requires_auth(self, client): + """POST /v1/settings/load/json should require authentication.""" + response = client.post( + "/v1/settings/load/json", + params={"client": "test"}, + json={"client_settings": {"client": "test"}}, + ) + + assert response.status_code == 401 + + def test_load_from_json_success(self, client, auth_headers, settings_objects_manager): + """POST /v1/settings/load/json should load configuration from JSON payload.""" + # pylint: disable=unused-argument + config_data = { + "client_settings": { + "client": "json_load_client", + "ll_model": { + "model": "gpt-4o-mini", + "temperature": 0.9, + "max_tokens": 500, + "chat_history": True, + }, + }, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + response = client.post( + "/v1/settings/load/json", + headers=auth_headers["valid_auth"], + params={"client": "json_load_client"}, + json=config_data, + ) + + assert response.status_code == 200 + assert "loaded successfully" in response.json()["message"].lower() diff --git a/test/unit/server/api/utils/test_utils_databases.py b/test/unit/server/api/utils/test_utils_databases.py index 4b67062f..ebc92d80 100644 --- a/test/unit/server/api/utils/test_utils_databases.py +++ b/test/unit/server/api/utils/test_utils_databases.py @@ -221,6 +221,34 @@ def test_connect_wallet_location_defaults_to_config_dir(self, mock_connect, make call_kwargs = mock_connect.call_args.kwargs assert call_kwargs.get("wallet_location") == "/path/to/config" + 
@patch("server.api.utils.databases.oracledb.connect") + def test_connect_raises_permission_error_on_ora_28009(self, mock_connect, make_database): + """connect should raise PermissionError with custom message on ORA-28009 (mocked).""" + # Create a mock error object with full_code and message + mock_error = MagicMock() + mock_error.full_code = "ORA-28009" + mock_error.message = "connection not allowed" + mock_connect.side_effect = oracledb.DatabaseError(mock_error) + config = make_database(user="SYS") + + with pytest.raises(PermissionError) as exc_info: + utils_databases.connect(config) + + assert "Connecting as SYS is not permitted" in str(exc_info.value) + + @patch("server.api.utils.databases.oracledb.connect") + def test_connect_reraises_unmapped_database_error(self, mock_connect, make_database): + """connect should re-raise unmapped DatabaseError codes (mocked).""" + # Create a mock error object with an unmapped error code + mock_error = MagicMock() + mock_error.full_code = "ORA-12345" + mock_error.message = "some other error" + mock_connect.side_effect = oracledb.DatabaseError(mock_error) + config = make_database() + + with pytest.raises(oracledb.DatabaseError): + utils_databases.connect(config) + class TestDisconnect: """Tests for the disconnect function.""" @@ -294,6 +322,86 @@ def test_execute_sql_multiple_rows(self, db_transaction): assert result[1] == (2,) assert result[2] == (3,) + def test_execute_sql_logs_table_exists_error(self, db_connection, caplog): + """execute_sql should log ORA-00955 table exists error (real database). + + Note: Due to a bug in the source code (two if statements instead of elif), + the function logs 'Table exists' but still raises. This test verifies + the logging behavior and that the error is raised. + """ + cursor = db_connection.cursor() + table_name = "TEST_DUPLICATE_TABLE" + + try: + # Create table first + cursor.execute(f"CREATE TABLE {table_name} (id NUMBER)") + db_connection.commit() + + # Try to create it again - logs 'Table exists' but raises due to bug + with pytest.raises(oracledb.DatabaseError): + utils_databases.execute_sql( + db_connection, + f"CREATE TABLE {table_name} (id NUMBER)", + ) + + # Verify the logging happened + assert "Table exists" in caplog.text + + finally: + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + db_connection.commit() + except oracledb.DatabaseError: + pass + cursor.close() + + def test_execute_sql_handles_table_not_exists_error(self, db_connection, caplog): + """execute_sql should handle ORA-00942 table not exists error (real database). + + The function logs 'Table does not exist' and returns None (doesn't raise) + for error code 942. 
+ """ + # Try to select from a non-existent table + result = utils_databases.execute_sql( + db_connection, + "SELECT * FROM NONEXISTENT_TABLE_12345", + ) + + # Should not raise, returns None + assert result is None + + # Verify the logging happened + assert "Table does not exist" in caplog.text + + def test_execute_sql_raises_on_other_database_error(self, db_transaction): + """execute_sql should raise on other DatabaseError codes (real database).""" + # Invalid SQL syntax should raise + with pytest.raises(oracledb.DatabaseError): + utils_databases.execute_sql(db_transaction, "INVALID SQL SYNTAX HERE") + + def test_execute_sql_raises_on_interface_error(self): + """execute_sql should raise on InterfaceError (mocked).""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_cursor.callproc.side_effect = oracledb.InterfaceError("Interface error") + + with pytest.raises(oracledb.InterfaceError): + utils_databases.execute_sql(mock_conn, "SELECT 1 FROM dual") + + def test_execute_sql_raises_on_database_error_no_args(self): + """execute_sql should raise on DatabaseError with no args (mocked).""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + # DatabaseError with empty args + mock_cursor.callproc.side_effect = oracledb.DatabaseError() + + with pytest.raises(oracledb.DatabaseError): + utils_databases.execute_sql(mock_conn, "SELECT 1 FROM dual") + class TestDropVs: """Tests for the drop_vs function.""" @@ -466,6 +574,19 @@ def test_test_raises_db_exception_on_connection_error(self, make_database): assert exc_info.value.status_code == 503 + def test_test_raises_db_exception_on_generic_exception(self, make_database): + """_test should raise DbException with 500 on generic Exception.""" + config = make_database() + mock_conn = MagicMock() + mock_conn.ping.side_effect = RuntimeError("Unexpected error") + config.set_connection(mock_conn) + + with pytest.raises(DbException) as exc_info: + utils_databases._test(config) + + assert exc_info.value.status_code == 500 + assert "Unexpected error" in exc_info.value.detail + class TestGetVs: # pylint: disable=protected-access """Tests for the _get_vs function. 
@@ -489,6 +610,40 @@ def test_get_vs_empty_for_clean_schema(self, db_transaction): # Either empty or returns actual vector stores if they exist assert isinstance(result, list) + def test_get_vs_parses_genai_comment(self, db_connection): + """_get_vs should parse GENAI comment JSON and return DatabaseVectorStorage (real database).""" + cursor = db_connection.cursor() + table_name = "VS_TEST_TABLE" + + try: + # Create a test table + cursor.execute(f"CREATE TABLE {table_name} (id NUMBER, data VARCHAR2(100))") + + # Add GENAI comment with JSON metadata (matching the expected format) + comment_json = '{"description": "Test vector store"}' + cursor.execute(f"COMMENT ON TABLE {table_name} IS 'GENAI: {comment_json}'") + db_connection.commit() + + # Test _get_vs + result = utils_databases._get_vs(db_connection) + + # Should find our test table + vs_names = [vs.vector_store for vs in result] + assert table_name in vs_names + + # Find our test vector store and verify parsed data + test_vs = next(vs for vs in result if vs.vector_store == table_name) + assert test_vs.description == "Test vector store" + + finally: + # Cleanup - drop table + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + db_connection.commit() + except oracledb.DatabaseError: + pass + cursor.close() + class TestLoggerConfiguration: """Tests for logger configuration.""" diff --git a/test/unit/server/api/v1/test_v1_databases.py b/test/unit/server/api/v1/test_v1_databases.py index 7d052cd9..98f957a4 100644 --- a/test/unit/server/api/v1/test_v1_databases.py +++ b/test/unit/server/api/v1/test_v1_databases.py @@ -44,6 +44,7 @@ async def test_databases_list_returns_empty_list(self, mock_get_databases): result = await databases.databases_list() assert result == [] + mock_get_databases.assert_called_once_with(validate=False) @pytest.mark.asyncio @patch("server.api.v1.databases.utils_databases.get_databases") @@ -55,6 +56,7 @@ async def test_databases_list_raises_404_on_value_error(self, mock_get_databases await databases.databases_list() assert exc_info.value.status_code == 404 + mock_get_databases.assert_called_once_with(validate=False) class TestDatabasesGet: @@ -82,6 +84,7 @@ async def test_databases_get_raises_404_when_not_found(self, mock_get_databases) await databases.databases_get(name="NONEXISTENT") assert exc_info.value.status_code == 404 + mock_get_databases.assert_called_once_with(db_name="NONEXISTENT", validate=True) class TestDatabasesUpdate: @@ -92,7 +95,7 @@ class TestDatabasesUpdate: @patch("server.api.v1.databases.utils_databases.connect") @patch("server.api.v1.databases.utils_databases.disconnect") async def test_databases_update_returns_updated_database( - self, _mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth + self, mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth ): """databases_update should return the updated database.""" existing_db = make_database(name="TEST_DB", user="old_user") @@ -107,6 +110,21 @@ async def test_databases_update_returns_updated_database( assert result.user == "new_user" assert result.connected is True + # Verify get_databases called twice: first to get target DB, second to get all DBs for cleanup + assert mock_get_databases.call_count == 2 + mock_get_databases.assert_any_call(db_name="TEST_DB", validate=False) + mock_get_databases.assert_any_call() + + # Verify connect was called with the payload (which has config_dir/wallet_location set from db) + mock_connect.assert_called_once() + connect_arg = 
mock_connect.call_args[0][0] + assert connect_arg.user == "new_user" + assert connect_arg.password == "new_pass" + assert connect_arg.dsn == "localhost:1521/TEST" + + # Verify disconnect was NOT called (no other databases with connections) + mock_disconnect.assert_not_called() + @pytest.mark.asyncio @patch("server.api.v1.databases.utils_databases.get_databases") async def test_databases_update_raises_404_when_not_found(self, mock_get_databases, make_database_auth): @@ -119,6 +137,7 @@ async def test_databases_update_raises_404_when_not_found(self, mock_get_databas await databases.databases_update(name="NONEXISTENT", payload=payload) assert exc_info.value.status_code == 404 + mock_get_databases.assert_called_once_with(db_name="NONEXISTENT", validate=False) @pytest.mark.asyncio @patch("server.api.v1.databases.utils_databases.get_databases") @@ -138,6 +157,15 @@ async def test_databases_update_raises_400_on_value_error( assert exc_info.value.status_code == 400 + # Verify get_databases was called to retrieve the target database + mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=False) + + # Verify connect was called with the payload + mock_connect.assert_called_once() + connect_arg = mock_connect.call_args[0][0] + assert connect_arg.user == payload.user + assert connect_arg.dsn == payload.dsn + @pytest.mark.asyncio @patch("server.api.v1.databases.utils_databases.get_databases") @patch("server.api.v1.databases.utils_databases.connect") @@ -156,6 +184,84 @@ async def test_databases_update_raises_401_on_permission_error( assert exc_info.value.status_code == 401 + # Verify get_databases was called to retrieve the target database + mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=False) + + # Verify connect was called with the payload + mock_connect.assert_called_once() + connect_arg = mock_connect.call_args[0][0] + assert connect_arg.user == payload.user + assert connect_arg.dsn == payload.dsn + + @pytest.mark.asyncio + @patch("server.api.v1.databases.utils_databases.get_databases") + @patch("server.api.v1.databases.utils_databases.connect") + @patch("server.api.v1.databases.utils_databases.disconnect") + async def test_databases_update_disconnects_other_databases( + self, mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth + ): + """databases_update should disconnect OTHER database connections, not the newly connected one. + + When connecting to a database, the system enforces single-connection mode: + only one database can be connected at a time. This test verifies that when + updating/connecting to TEST_DB, any existing connections on OTHER databases + are properly disconnected using their own connection objects. + + Expected behavior: + 1. Connect to TEST_DB with new connection + 2. For each other database with an active connection, disconnect it + 3. The disconnect call should receive the OTHER database's connection + 4. 
The newly connected database's connection should remain intact + """ + # Setup: TEST_DB is the database being updated + target_db = make_database(name="TEST_DB", user="old_user") + + # Setup: OTHER_DB has an existing connection that should be disconnected + other_db = make_database(name="OTHER_DB") + other_db_existing_connection = MagicMock(name="other_db_connection") + other_db.set_connection(other_db_existing_connection) + other_db.connected = True + + # Setup: ANOTHER_DB has no connection (should not trigger disconnect) + another_db = make_database(name="ANOTHER_DB") + another_db.connected = False + + # Mock: First call returns target DB, second call returns all DBs for cleanup + mock_get_databases.side_effect = [target_db, [target_db, other_db, another_db]] + + # Mock: New connection for TEST_DB + new_connection = MagicMock(name="new_test_db_connection") + mock_connect.return_value = new_connection + + # Mock: disconnect returns None (connection closed) + mock_disconnect.return_value = None + + payload = make_database_auth(user="new_user", password="new_pass", dsn="localhost:1521/TEST") + + # Execute + result = await databases.databases_update(name="TEST_DB", payload=payload) + + # Verify: Target database is connected with new connection + assert result.connected is True + assert result.user == "new_user" + + # Verify: disconnect was called exactly once (only OTHER_DB had a connection) + mock_disconnect.assert_called_once() + + # CRITICAL ASSERTION: disconnect must be called with OTHER_DB's connection, + # not the new TEST_DB connection + actual_disconnect_arg = mock_disconnect.call_args[0][0] + assert actual_disconnect_arg is other_db_existing_connection, ( + f"Expected disconnect to be called with other_db's connection, " + f"but was called with: {actual_disconnect_arg}" + ) + assert actual_disconnect_arg is not new_connection, ( + "disconnect should NOT be called with the newly created connection" + ) + + # Verify: OTHER_DB is now disconnected + assert other_db.connected is False + class TestRouterConfiguration: """Tests for router configuration.""" diff --git a/tests/server/unit/api/utils/test_utils_databases_crud.py b/tests/server/unit/api/utils/test_utils_databases_crud.py index 6ad67d00..62c06ef0 100644 --- a/tests/server/unit/api/utils/test_utils_databases_crud.py +++ b/tests/server/unit/api/utils/test_utils_databases_crud.py @@ -27,36 +27,9 @@ def setup_method(self): name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" ) - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_all(self, mock_database_objects): - """Test getting all databases when no name is provided""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) - mock_database_objects.__len__ = MagicMock(return_value=2) - - result = databases.get() - - assert result == [self.sample_database, self.sample_database_2] - assert len(result) == 2 - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_by_name_found(self, mock_database_objects): - """Test getting database by name when it exists""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) - mock_database_objects.__len__ = MagicMock(return_value=2) - - result = databases.get(name="test_db") - - assert result == [self.sample_database] - assert len(result) == 1 - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_by_name_not_found(self, mock_database_objects): - """Test 
getting database by name when it doesn't exist""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database])) - mock_database_objects.__len__ = MagicMock(return_value=1) - - with pytest.raises(ValueError, match="nonexistent not found"): - databases.get(name="nonexistent") + # test_get_all: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_all_databases + # test_get_by_name_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_specific_database + # test_get_by_name_not_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_raises_unknown_error @patch("server.api.utils.databases.DATABASE_OBJECTS") def test_get_empty_list(self, mock_database_objects): @@ -77,54 +50,9 @@ def test_get_empty_list_with_name(self, mock_database_objects): with pytest.raises(ValueError, match="test_db not found"): databases.get(name="test_db") - def test_create_success(self, db_container, db_objects_manager): - """Test successful database creation when database doesn't exist""" - assert db_container is not None - assert db_objects_manager is not None - # Clear the list to start fresh - databases.DATABASE_OBJECTS.clear() - - # Create a new database - new_database = Database(name="new_test_db", user="test_user", password="test_password", dsn="test_dsn") - - result = databases.create(new_database) - - # Verify database was added - assert len(databases.DATABASE_OBJECTS) == 1 - assert databases.DATABASE_OBJECTS[0].name == "new_test_db" - assert result == [new_database] - - def test_create_already_exists(self, db_container, db_objects_manager): - """Test database creation when database already exists""" - assert db_container is not None - assert db_objects_manager is not None - # Add a database to the list - databases.DATABASE_OBJECTS.clear() - existing_db = Database(name="existing_db", user="test_user", password="test_password", dsn="test_dsn") - databases.DATABASE_OBJECTS.append(existing_db) - - # Try to create a database with the same name - duplicate_db = Database(name="existing_db", user="other_user", password="other_password", dsn="other_dsn") - - # Should raise an error for duplicate database - with pytest.raises(ValueError, match="Database: existing_db already exists"): - databases.create(duplicate_db) - - # Verify only original database exists - assert len(databases.DATABASE_OBJECTS) == 1 - assert databases.DATABASE_OBJECTS[0] == existing_db - - def test_create_missing_user(self, db_container, db_objects_manager): - """Test database creation with missing user field""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - - # Create database with missing user - incomplete_db = Database(name="incomplete_db", password="test_password", dsn="test_dsn") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create(incomplete_db) + # test_create_success: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_success + # test_create_already_exists: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_exists_error + # test_create_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_value_error_missing_fields def test_create_missing_password(self, db_container, db_objects_manager): """Test database creation with missing password field""" @@ -162,27 +90,7 @@ def test_create_multiple_missing_fields(self, 
db_container, db_objects_manager): with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): databases.create(incomplete_db) - def test_delete(self, db_container, db_objects_manager): - """Test database deletion""" - assert db_container is not None - assert db_objects_manager is not None - # Setup test data - db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="test_db_2", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="test_db_3", user="user3", password="pass3", dsn="dsn3") - - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Delete middle database - databases.delete("test_db_2") - - # Verify deletion - assert len(databases.DATABASE_OBJECTS) == 2 - names = [db.name for db in databases.DATABASE_OBJECTS] - assert "test_db_1" in names - assert "test_db_2" not in names - assert "test_db_3" in names + # test_delete: See test/unit/server/api/utils/test_utils_databases.py::TestDelete::test_delete_removes_database def test_delete_nonexistent(self, db_container, db_objects_manager): """Test deleting non-existent database""" @@ -234,10 +142,7 @@ def test_delete_multiple_same_name(self, db_container, db_objects_manager): assert len(databases.DATABASE_OBJECTS) == 1 assert databases.DATABASE_OBJECTS[0].name == "other" - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(databases, "logger") - assert databases.logger.name == "api.utils.database" + # test_logger_exists: See test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists def test_get_filters_correctly(self, db_container, db_objects_manager): """Test that get correctly filters by name""" @@ -321,12 +226,7 @@ def test_create_real_scenario(self, db_container, db_objects_manager): class TestDbException: """Test custom database exception class""" - def test_db_exception_initialization(self): - """Test DbException initialization""" - exc = DbException(status_code=500, detail="Database error") - assert exc.status_code == 500 - assert exc.detail == "Database error" - assert str(exc) == "Database error" + # test_db_exception_initialization: See test/unit/server/api/utils/test_utils_databases.py::TestDbException::test_db_exception_init def test_db_exception_inheritance(self): """Test DbException inherits from Exception""" diff --git a/tests/server/unit/api/utils/test_utils_databases_functions.py b/tests/server/unit/api/utils/test_utils_databases_functions.py index 42736539..a4ac6b74 100644 --- a/tests/server/unit/api/utils/test_utils_databases_functions.py +++ b/tests/server/unit/api/utils/test_utils_databases_functions.py @@ -31,90 +31,13 @@ def setup_method(self): dsn=TEST_CONFIG["db_dsn"], ) - def test_test_function_success(self, db_container): - """Test successful database connection test with real database""" - assert db_container is not None - # Connect to real database - conn = databases.connect(self.sample_database) - self.sample_database.set_connection(conn) - - try: - # Test the connection - databases._test(self.sample_database) - assert self.sample_database.connected is True - finally: - databases.disconnect(conn) - - @patch("oracledb.Connection") - def test_test_function_reconnect(self, mock_connection): - """Test database reconnection when ping fails""" - mock_connection.ping.side_effect = oracledb.DatabaseError("Connection lost") - self.sample_database.set_connection(mock_connection) - - with 
patch("server.api.utils.databases.connect") as mock_connect: - databases._test(self.sample_database) - mock_connect.assert_called_once_with(self.sample_database) - - @patch("oracledb.Connection") - def test_test_function_value_error(self, mock_connection): - """Test handling of value errors""" - mock_connection.ping.side_effect = ValueError("Invalid value") - self.sample_database.set_connection(mock_connection) - - with pytest.raises(DbException) as exc_info: - databases._test(self.sample_database) - - assert exc_info.value.status_code == 400 - assert "Database: Invalid value" in str(exc_info.value) - - @patch("oracledb.Connection") - def test_test_function_permission_error(self, mock_connection): - """Test handling of permission errors""" - mock_connection.ping.side_effect = PermissionError("Access denied") - self.sample_database.set_connection(mock_connection) - - with pytest.raises(DbException) as exc_info: - databases._test(self.sample_database) - - assert exc_info.value.status_code == 401 - assert "Database: Access denied" in str(exc_info.value) - - @patch("oracledb.Connection") - def test_test_function_connection_error(self, mock_connection): - """Test handling of connection errors""" - mock_connection.ping.side_effect = ConnectionError("Connection failed") - self.sample_database.set_connection(mock_connection) - - with pytest.raises(DbException) as exc_info: - databases._test(self.sample_database) - - assert exc_info.value.status_code == 503 - assert "Database: Connection failed" in str(exc_info.value) - - @patch("oracledb.Connection") - def test_test_function_generic_exception(self, mock_connection): - """Test handling of generic exceptions""" - mock_connection.ping.side_effect = RuntimeError("Unknown error") - self.sample_database.set_connection(mock_connection) - - with pytest.raises(DbException) as exc_info: - databases._test(self.sample_database) - - assert exc_info.value.status_code == 500 - assert "Unknown error" in str(exc_info.value) - - def test_get_vs_with_real_database(self, db_container): - """Test vector storage retrieval with real database""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Test with empty result (no vector stores initially) - result = databases._get_vs(conn) - assert isinstance(result, list) - assert len(result) == 0 # Initially no vector stores - finally: - databases.disconnect(conn) + # test_test_function_success: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_active + # test_test_function_reconnect: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_refreshes_on_database_error + # test_test_function_value_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_value_error + # test_test_function_permission_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_permission_error + # test_test_function_connection_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_connection_error + # test_test_function_generic_exception: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_generic_exception + # test_get_vs_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestGetVs::test_get_vs_returns_list @patch("server.api.utils.databases.execute_sql") def 
test_get_vs_with_mock_data(self, mock_execute_sql): @@ -181,54 +104,10 @@ def setup_method(self): dsn=TEST_CONFIG["db_dsn"], ) - def test_connect_success_with_real_database(self, db_container): - """Test successful database connection with real database""" - assert db_container is not None - result = databases.connect(self.sample_database) - - try: - assert result is not None - assert isinstance(result, oracledb.Connection) - # Test that connection is active - result.ping() - finally: - databases.disconnect(result) - - def test_connect_missing_user(self): - """Test connection with missing user""" - incomplete_db = Database( - name="test_db", - user="", # Missing user - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - - with pytest.raises(ValueError, match="missing connection details"): - databases.connect(incomplete_db) - - def test_connect_missing_password(self): - """Test connection with missing password""" - incomplete_db = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password="", # Missing password - dsn=TEST_CONFIG["db_dsn"], - ) - - with pytest.raises(ValueError, match="missing connection details"): - databases.connect(incomplete_db) - - def test_connect_missing_dsn(self): - """Test connection with missing DSN""" - incomplete_db = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn="", # Missing DSN - ) - - with pytest.raises(ValueError, match="missing connection details"): - databases.connect(incomplete_db) + # test_connect_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_success_real_db + # test_connect_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details + # test_connect_missing_password: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details + # test_connect_missing_dsn: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details def test_connect_with_wallet_configuration(self, db_container): """Test connection with wallet configuration""" @@ -271,18 +150,7 @@ def test_connect_wallet_password_without_location(self, db_container): # Expected if wallet doesn't exist pass - def test_connect_invalid_credentials(self, db_container): - """Test connection with invalid credentials""" - assert db_container is not None - invalid_db = Database( - name="test_db", - user="invalid_user", - password="invalid_password", - dsn=TEST_CONFIG["db_dsn"], - ) - - with pytest.raises(PermissionError): - databases.connect(invalid_db) + # test_connect_invalid_credentials: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_permission_error_invalid_credentials def test_connect_invalid_dsn(self, db_container): """Test connection with invalid DSN""" @@ -298,45 +166,9 @@ def test_connect_invalid_dsn(self, db_container): with pytest.raises(Exception): # Catch any exception - DNS resolution errors vary by environment databases.connect(invalid_db) - def test_disconnect_success(self, db_container): - """Test successful database disconnection""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - result = databases.disconnect(conn) - - assert result is None - # Try to use connection after disconnect - should fail - with pytest.raises(oracledb.InterfaceError): - conn.ping() - - def 
test_execute_sql_success_with_real_database(self, db_container): - """Test successful SQL execution with real database""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Test simple query - result = databases.execute_sql(conn, "SELECT 1 FROM DUAL") - assert result is not None - assert len(result) == 1 - assert result[0][0] == 1 - finally: - databases.disconnect(conn) - - def test_execute_sql_with_binds(self, db_container): - """Test SQL execution with bind variables using real database""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - binds = {"test_value": 42} - result = databases.execute_sql(conn, "SELECT :test_value FROM DUAL", binds) - assert result is not None - assert len(result) == 1 - assert result[0][0] == 42 - finally: - databases.disconnect(conn) + # test_disconnect_success: See test/unit/server/api/utils/test_utils_databases.py::TestDisconnect::test_disconnect_closes_connection + # test_execute_sql_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_returns_rows + # test_execute_sql_with_binds: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_with_binds def test_execute_sql_no_rows(self, db_container): """Test SQL execution that returns no rows""" @@ -411,31 +243,14 @@ def test_execute_sql_table_not_exists_error(self, db_container): finally: databases.disconnect(conn) - def test_execute_sql_invalid_syntax(self, db_container): - """Test SQL execution with invalid syntax""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - with pytest.raises(oracledb.DatabaseError): - databases.execute_sql(conn, "INVALID SQL STATEMENT") - finally: - databases.disconnect(conn) + # test_execute_sql_invalid_syntax: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_raises_on_other_database_error def test_drop_vs_function_exists(self): """Test that drop_vs function exists and is callable""" assert hasattr(databases, "drop_vs") assert callable(databases.drop_vs) - @patch("langchain_community.vectorstores.oraclevs.drop_table_purge") - def test_drop_vs_calls_langchain(self, mock_drop_table): - """Test drop_vs calls LangChain drop_table_purge""" - mock_connection = MagicMock() - vs_name = "TEST_VECTOR_STORE" - - databases.drop_vs(mock_connection, vs_name) - - mock_drop_table.assert_called_once_with(mock_connection, vs_name) + # test_drop_vs_calls_langchain: See test/unit/server/api/utils/test_utils_databases.py::TestDropVs::test_drop_vs_calls_langchain class TestDatabaseUtilsQueryFunctions: @@ -595,7 +410,4 @@ def test_get_client_database_with_validation(self, mock_get_settings, db_contain if db.connection: databases.disconnect(db.connection) - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(databases, "logger") - assert databases.logger.name == "api.utils.database" + # test_logger_exists: See test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py index 70491d8d..d7451dde 100644 --- a/tests/server/unit/api/utils/test_utils_models.py +++ b/tests/server/unit/api/utils/test_utils_models.py @@ -21,29 +21,11 @@ class TestModelsExceptions: """Test custom exception classes""" - def test_url_unreachable_error(self): - """Test 
URLUnreachableError exception""" - error = URLUnreachableError("URL is unreachable") - assert str(error) == "URL is unreachable" - assert isinstance(error, ValueError) - - def test_invalid_model_error(self): - """Test InvalidModelError exception""" - error = InvalidModelError("Invalid model data") - assert str(error) == "Invalid model data" - assert isinstance(error, ValueError) - - def test_exists_model_error(self): - """Test ExistsModelError exception""" - error = ExistsModelError("Model already exists") - assert str(error) == "Model already exists" - assert isinstance(error, ValueError) - - def test_unknown_model_error(self): - """Test UnknownModelError exception""" - error = UnknownModelError("Model not found") - assert str(error) == "Model not found" - assert isinstance(error, ValueError) + # test_url_unreachable_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_url_unreachable_error_is_value_error + # test_invalid_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_invalid_model_error_is_value_error + # test_exists_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_exists_model_error_is_value_error + # test_unknown_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_unknown_model_error_is_value_error + pass ##################################################### @@ -64,15 +46,7 @@ def disabled_model(self): """Disabled model fixture""" return Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_all_models(self, mock_model_objects, sample_model, disabled_model): - """Test getting all models without filters""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model, disabled_model])) - mock_model_objects.__len__ = MagicMock(return_value=2) - - result = models.get() - - assert result == [sample_model, disabled_model] + # test_get_model_all_models: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_all_models @patch("server.api.utils.models.MODEL_OBJECTS") def test_get_model_by_id_found(self, mock_model_objects, sample_model): @@ -93,58 +67,12 @@ def test_get_model_by_id_not_found(self, mock_model_objects, sample_model): with pytest.raises(UnknownModelError, match="nonexistent not found"): models.get(model_id="nonexistent") - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_provider(self, mock_model_objects, sample_model, disabled_model): - """Test filtering models by provider""" - all_models = [sample_model, disabled_model] - mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) - mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) - - (result,) = models.get(model_provider="openai") - - # Since only one model matches provider="openai", it will return a list of single model - assert result == sample_model - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_type(self, mock_model_objects, sample_model, disabled_model): - """Test filtering models by type""" - all_models = [sample_model, disabled_model] - mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) - mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) - - result = models.get(model_type="ll") - - assert result == all_models - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_exclude_disabled(self, mock_model_objects, sample_model, 
disabled_model): - """Test excluding disabled models""" - all_models = [sample_model, disabled_model] - mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) - mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) - - (result,) = models.get(include_disabled=False) - assert result == sample_model - - @patch("server.api.utils.models.MODEL_OBJECTS", []) - @patch("server.api.utils.models.is_url_accessible") - def test_create_model_success(self, mock_url_check, sample_model): - """Test successful model creation""" - mock_url_check.return_value = (True, None) - - result = models.create(sample_model) - - assert result == sample_model - assert result in models.MODEL_OBJECTS - - @patch("server.api.utils.models.MODEL_OBJECTS") - @patch("server.api.utils.models.get") - def test_create_model_already_exists(self, mock_get_model, _mock_model_objects, sample_model): - """Test creating model that already exists""" - mock_get_model.return_value = sample_model # Model already exists + # test_get_model_by_provider: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_provider + # test_get_model_by_type: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_type + # test_get_model_exclude_disabled: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_exclude_disabled - with pytest.raises(ExistsModelError, match="Model: openai/test-model already exists"): - models.create(sample_model) + # test_create_model_success: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_success + # test_create_model_already_exists: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_raises_exists_error @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") @@ -173,25 +101,9 @@ def test_create_model_skip_url_check(self, sample_model): assert result == sample_model assert result in models.MODEL_OBJECTS - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_delete_model(self, mock_model_objects): - """Test model deletion""" - test_models = [ - Model(id="test-model", provider="openai", type="ll"), - Model(id="other-model", provider="anthropic", type="ll"), - ] - mock_model_objects.__setitem__ = MagicMock() - mock_model_objects.__iter__ = MagicMock(return_value=iter(test_models)) - - models.delete("openai", "test-model") - - # Verify the slice assignment was called - mock_model_objects.__setitem__.assert_called_once() - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(models, "logger") - assert models.logger.name == "api.utils.models" + # test_delete_model: See test/unit/server/api/utils/test_utils_models.py::TestDelete::test_delete_removes_model + # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists + pass ##################################################### @@ -212,26 +124,7 @@ def sample_oci_config(self): """Sample OCI config fixture""" return get_sample_oci_config() - @patch("server.api.utils.models.MODEL_OBJECTS", []) - @patch("server.api.utils.models.is_url_accessible") - def test_update_success(self, mock_url_check, sample_model): - """Test successful model update""" - # First create the model - models.MODEL_OBJECTS.append(sample_model) - mock_url_check.return_value = (True, None) - - update_payload = Model( - id="test-model", - provider="openai", - type="ll", - enabled=True, - api_base="https://api.openai.com", - 
temperature=0.8, - ) - - result = models.update(update_payload) - - assert result.temperature == 0.8 + # test_update_success: See test/unit/server/api/utils/test_utils_models.py::TestUpdate::test_update_success @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") @@ -296,45 +189,9 @@ def test_update_multiple_fields(self, mock_url_check, sample_model): assert result.temperature == 0.5 assert result.max_tokens == 2048 - @patch("server.api.utils.models.get") - def test_get_full_config_success(self, mock_get_model, sample_model, sample_oci_config): - """Test successful full config retrieval""" - mock_get_model.return_value = [sample_model] - model_config = {"model": "openai/gpt-4", "temperature": 0.8} - - full_config, provider = models._get_full_config(model_config, sample_oci_config) - - assert provider == "openai" - assert full_config["temperature"] == 0.8 - assert full_config["id"] == "test-model" - mock_get_model.assert_called_once_with(model_provider="openai", model_id="gpt-4", include_disabled=False) - - @patch("server.api.utils.models.get") - def test_get_full_config_unknown_model(self, mock_get_model, sample_oci_config): - """Test full config retrieval with unknown model""" - mock_get_model.side_effect = UnknownModelError("Model not found") - model_config = {"model": "unknown/model"} - - with pytest.raises(UnknownModelError): - models._get_full_config(model_config, sample_oci_config) - - @patch("server.api.utils.models._get_full_config") - @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config, sample_oci_config): - """Test basic LiteLLM config generation""" - mock_get_full_config.return_value = ( - {"temperature": 0.7, "max_tokens": 4096, "api_base": "https://api.openai.com"}, - "openai", - ) - mock_get_params.return_value = ["temperature", "max_tokens"] - model_config = {"model": "openai/gpt-4"} - - result = models.get_litellm_config(model_config, sample_oci_config) - - assert result["model"] == "openai/gpt-4" - assert result["temperature"] == 0.7 - assert result["max_tokens"] == 4096 - assert result["drop_params"] is True + # test_get_full_config_success: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_success + # test_get_full_config_unknown_model: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_raises_unknown_model + # test_get_litellm_config_basic: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_basic @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") @@ -366,21 +223,7 @@ def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config, sam assert "presence_penalty" not in result assert "frequency_penalty" not in result - @patch("server.api.utils.models._get_full_config") - @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config, sample_oci_config): - """Test LiteLLM config generation for OCI""" - mock_get_full_config.return_value = ({"temperature": 0.7}, "oci") - mock_get_params.return_value = ["temperature"] - model_config = {"model": "oci/cohere.command"} - - result = models.get_litellm_config(model_config, sample_oci_config) - - assert result["oci_user"] == "ocid1.user.oc1..testuser" - assert result["oci_fingerprint"] == "test-fingerprint" - assert result["oci_tenancy"] == 
"ocid1.tenancy.oc1..testtenant" - assert result["oci_region"] == "us-ashburn-1" - assert result["oci_key_file"] == "/path/to/key.pem" + # test_get_litellm_config_oci: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_oci_provider @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") @@ -395,7 +238,4 @@ def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config, assert "model" not in result assert "temperature" not in result - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(models, "logger") - assert models.logger.name == "api.utils.models" + # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_oci.py b/tests/server/unit/api/utils/test_utils_oci.py index 1e8d4819..39c0a4f1 100644 --- a/tests/server/unit/api/utils/test_utils_oci.py +++ b/tests/server/unit/api/utils/test_utils_oci.py @@ -19,12 +19,7 @@ class TestOciException: """Test custom OCI exception class""" - def test_oci_exception_initialization(self): - """Test OciException initialization""" - exc = OciException(status_code=400, detail="Invalid configuration") - assert exc.status_code == 400 - assert exc.detail == "Invalid configuration" - assert str(exc) == "Invalid configuration" + # test_oci_exception_initialization: See test/unit/server/api/utils/test_utils_oci.py::TestOciException::test_oci_exception_init class TestOciGet: @@ -49,38 +44,10 @@ def sample_client_settings(self): """Sample client settings fixture""" return Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) - @patch("server.bootstrap.bootstrap.OCI_OBJECTS", []) - def test_get_no_objects_configured(self): - """Test getting OCI settings when none are configured""" - with pytest.raises(ValueError, match="not configured"): - oci_utils.get() - - @patch("server.bootstrap.bootstrap.OCI_OBJECTS", new_callable=list) - def test_get_all(self, mock_oci_objects, sample_oci_default, sample_oci_custom): - """Test getting all OCI settings when no filters are provided""" - all_oci = [sample_oci_default, sample_oci_custom] - mock_oci_objects.extend(all_oci) - - result = oci_utils.get() - - assert result == all_oci - - @patch("server.bootstrap.bootstrap.OCI_OBJECTS") - def test_get_by_auth_profile_found(self, mock_oci_objects, sample_oci_default, sample_oci_custom): - """Test getting OCI settings by auth_profile when it exists""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([sample_oci_default, sample_oci_custom])) - - result = oci_utils.get(auth_profile="CUSTOM") - - assert result == sample_oci_custom - - @patch("server.bootstrap.bootstrap.OCI_OBJECTS") - def test_get_by_auth_profile_not_found(self, mock_oci_objects, sample_oci_default): - """Test getting OCI settings by auth_profile when it doesn't exist""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([sample_oci_default])) - - with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): - oci_utils.get(auth_profile="NONEXISTENT") + # test_get_no_objects_configured: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_when_not_configured + # test_get_all: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_returns_all_oci_objects + # test_get_by_auth_profile_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_by_auth_profile + # 
test_get_by_auth_profile_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_profile_not_found def test_get_by_client_with_oci_settings(self, sample_client_settings, sample_oci_default, sample_oci_custom): """Test getting OCI settings by client when client has OCI settings""" @@ -126,14 +93,7 @@ def test_get_by_client_without_oci_settings(self, sample_oci_default): bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - @patch("server.bootstrap.bootstrap.OCI_OBJECTS") - @patch("server.bootstrap.bootstrap.SETTINGS_OBJECTS") - def test_get_by_client_not_found(self, mock_settings_objects, _mock_oci_objects): - """Test getting OCI settings when client doesn't exist""" - mock_settings_objects.__iter__ = MagicMock(return_value=iter([])) - - with pytest.raises(ValueError, match="client test_client not found"): - oci_utils.get(client="test_client") + # test_get_by_client_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_client_not_found def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_oci_default): """Test getting OCI settings by client when no matching profile exists""" @@ -156,48 +116,15 @@ def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_ bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - def test_get_both_client_and_auth_profile(self): - """Test that providing both client and auth_profile raises an error""" - with pytest.raises(ValueError, match="provide either 'client' or 'auth_profile', not both"): - oci_utils.get(client="test_client", auth_profile="CUSTOM") + # test_get_both_client_and_auth_profile: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_both_params class TestGetSigner: """Test get_signer() function""" - def test_get_signer_instance_principal(self): - """Test get_signer with instance_principal authentication""" - config = OracleCloudSettings(auth_profile="DEFAULT", authentication="instance_principal") - - with patch("oci.auth.signers.InstancePrincipalsSecurityTokenSigner") as mock_signer: - mock_instance = MagicMock() - mock_signer.return_value = mock_instance - - result = oci_utils.get_signer(config) - - assert result == mock_instance - mock_signer.assert_called_once() - - def test_get_signer_oke_workload_identity(self): - """Test get_signer with oke_workload_identity authentication""" - config = OracleCloudSettings(auth_profile="DEFAULT", authentication="oke_workload_identity") - - with patch("oci.auth.signers.get_oke_workload_identity_resource_principal_signer") as mock_signer: - mock_instance = MagicMock() - mock_signer.return_value = mock_instance - - result = oci_utils.get_signer(config) - - assert result == mock_instance - mock_signer.assert_called_once() - - def test_get_signer_api_key(self): - """Test get_signer with api_key authentication (returns None)""" - config = OracleCloudSettings(auth_profile="DEFAULT", authentication="api_key") - - result = oci_utils.get_signer(config) - - assert result is None + # test_get_signer_instance_principal: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_instance_principal + # test_get_signer_oke_workload_identity: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_oke_workload_identity + # test_get_signer_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_api_key_returns_none def 
test_get_signer_security_token(self): """Test get_signer with security_token authentication (returns None)""" @@ -224,18 +151,7 @@ def api_key_config(self): key_file="/path/to/key.pem", ) - @patch("oci.object_storage.ObjectStorageClient") - @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_api_key(self, mock_get_signer, mock_client_class, api_key_config): - """Test init_client with API key authentication""" - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - result = oci_utils.init_client(oci.object_storage.ObjectStorageClient, api_key_config) - - assert result == mock_client - mock_get_signer.assert_called_once_with(api_key_config) - mock_client_class.assert_called_once() + # test_init_client_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_standard_auth @patch("oci.generative_ai_inference.GenerativeAiInferenceClient") @patch.object(oci_utils, "get_signer", return_value=None) @@ -349,17 +265,7 @@ def test_init_client_with_security_token( mock_load_key.assert_called_once_with("/path/to/key.pem") mock_sec_token_signer.assert_called_once_with("mock_token_content", mock_private_key) - @patch("oci.object_storage.ObjectStorageClient") - @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_invalid_config(self, _mock_get_signer, mock_client_class, api_key_config): - """Test init_client with invalid config raises OciException""" - mock_client_class.side_effect = oci.exceptions.InvalidConfig("Bad config") - - with pytest.raises(OciException) as exc_info: - oci_utils.init_client(oci.object_storage.ObjectStorageClient, api_key_config) - - assert exc_info.value.status_code == 400 - assert "Invalid Config" in str(exc_info.value) + # test_init_client_invalid_config: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_raises_oci_exception_on_invalid_config class TestOciUtils: @@ -370,30 +276,8 @@ def sample_oci_config(self): """Sample OCI config fixture""" return get_sample_oci_config() - def test_init_genai_client(self, sample_oci_config): - """Test GenAI client initialization""" - with patch.object(oci_utils, "init_client") as mock_init_client: - mock_client = MagicMock() - mock_init_client.return_value = mock_client - - result = oci_utils.init_genai_client(sample_oci_config) - - assert result == mock_client - mock_init_client.assert_called_once_with( - oci.generative_ai_inference.GenerativeAiInferenceClient, sample_oci_config - ) - - @patch.object(oci_utils, "init_client") - def test_get_namespace_success(self, mock_init_client, sample_oci_config): - """Test successful namespace retrieval""" - mock_client = MagicMock() - mock_client.get_namespace.return_value.data = "test-namespace" - mock_init_client.return_value = mock_client - - result = oci_utils.get_namespace(sample_oci_config) - - assert result == "test-namespace" - assert sample_oci_config.namespace == "test-namespace" + # test_init_genai_client: See test/unit/server/api/utils/test_utils_oci.py::TestInitGenaiClient::test_init_genai_client_calls_init_client + # test_get_namespace_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_success @patch.object(oci_utils, "init_client") def test_get_namespace_invalid_config(self, mock_init_client, sample_oci_config): @@ -408,31 +292,8 @@ def test_get_namespace_invalid_config(self, mock_init_client, sample_oci_config) assert exc_info.value.status_code == 400 assert "Invalid Config" in str(exc_info.value) 
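+    # The get_namespace tests in this class (including the relocated ones
+    # referenced below) document an exception-translation pattern. A minimal
+    # sketch of that pattern (an illustrative assumption, not the verified
+    # implementation):
+    #
+    #     try:
+    #         client = init_client(oci.object_storage.ObjectStorageClient, config)
+    #         config.namespace = client.get_namespace().data
+    #     except oci.exceptions.InvalidConfig as ex:
+    #         raise OciException(status_code=400, detail=f"Invalid Config: {ex}") from ex
+    #     except FileNotFoundError as ex:
+    #         raise OciException(status_code=400, detail=f"Invalid Key Path: {ex}") from ex
+    #     except oci.exceptions.ServiceError as ex:
+    #         raise OciException(status_code=ex.status, detail=f"AuthN Error: {ex.message}") from ex
+    #     except Exception as ex:
+    #         raise OciException(status_code=500, detail=f"Unexpected error: {ex}") from ex
+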
- @patch.object(oci_utils, "init_client") - def test_get_namespace_file_not_found(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with file not found error""" - mock_init_client.side_effect = FileNotFoundError("Key file not found") - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 400 - assert "Invalid Key Path" in str(exc_info.value) - - @patch.object(oci_utils, "init_client") - def test_get_namespace_service_error(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with service error""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = oci.exceptions.ServiceError( - status=401, code="NotAuthenticated", headers={}, message="Auth failed" - ) - mock_init_client.return_value = mock_client - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 401 - assert "AuthN Error" in str(exc_info.value) + # test_get_namespace_file_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_file_not_found + # test_get_namespace_service_error: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_service_error @patch.object(oci_utils, "init_client") def test_get_namespace_unbound_local_error(self, mock_init_client, sample_oci_config): @@ -472,27 +333,5 @@ def test_get_namespace_generic_exception(self, mock_init_client, sample_oci_conf assert exc_info.value.status_code == 500 assert "Unexpected error" in str(exc_info.value) - @patch.object(oci_utils, "init_client") - def test_get_regions_success(self, mock_init_client, sample_oci_config): - """Test successful regions retrieval""" - mock_client = MagicMock() - mock_region = MagicMock() - mock_region.is_home_region = True - mock_region.region_key = "IAD" - mock_region.region_name = "us-ashburn-1" - mock_region.status = "READY" - mock_client.list_region_subscriptions.return_value.data = [mock_region] - mock_init_client.return_value = mock_client - - result = oci_utils.get_regions(sample_oci_config) - - assert len(result) == 1 - assert result[0]["is_home_region"] is True - assert result[0]["region_key"] == "IAD" - assert result[0]["region_name"] == "us-ashburn-1" - assert result[0]["status"] == "READY" - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(oci_utils, "logger") - assert oci_utils.logger.name == "api.utils.oci" + # test_get_regions_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetRegions::test_get_regions_returns_list + # test_logger_exists: See test/unit/server/api/utils/test_utils_oci.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_settings.py b/tests/server/unit/api/utils/test_utils_settings.py index aebff4d0..3027874f 100644 --- a/tests/server/unit/api/utils/test_utils_settings.py +++ b/tests/server/unit/api/utils/test_utils_settings.py @@ -44,67 +44,12 @@ def make_sample_config_data(): class TestClientSettings: """Test client settings CRUD operations""" - @patch("server.api.utils.settings.bootstrap") - def test_create_client_success(self, mock_bootstrap): - """Test successful client settings creation""" - default_cfg = make_default_settings() - settings_list = [default_cfg] - mock_bootstrap.SETTINGS_OBJECTS = settings_list - - result = settings.create_client("new_client") - - assert result.client == 
"new_client" - # Verify ll_model settings are copied from default - result_ll_model = result.model_dump()["ll_model"] - default_ll_model = default_cfg.model_dump()["ll_model"] - assert result_ll_model["max_tokens"] == default_ll_model["max_tokens"] - assert len(settings_list) == 2 - assert settings_list[-1].client == "new_client" - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_create_client_already_exists(self, mock_settings_objects): - """Test creating client settings when client already exists""" - test_cfg = make_test_client_settings() - mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) - - with pytest.raises(ValueError, match="client test_client already exists"): - settings.create_client("test_client") - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_get_client_found(self, mock_settings_objects): - """Test getting client settings when client exists""" - test_cfg = make_test_client_settings() - mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) - - result = settings.get_client("test_client") - - assert result == test_cfg - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_get_client_not_found(self, mock_settings_objects): - """Test getting client settings when client doesn't exist""" - default_cfg = make_default_settings() - mock_settings_objects.__iter__ = MagicMock(return_value=iter([default_cfg])) - - with pytest.raises(ValueError, match="client nonexistent not found"): - settings.get_client("nonexistent") - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - @patch("server.api.utils.settings.get_client") - def test_update_client(self, mock_get_settings, mock_settings_objects): - """Test updating client settings""" - test_cfg = make_test_client_settings() - mock_get_settings.return_value = test_cfg - mock_settings_objects.remove = MagicMock() - mock_settings_objects.append = MagicMock() - mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) - - new_settings = Settings(client="test_client", max_tokens=800, temperature=0.9) - result = settings.update_client(new_settings, "test_client") - - assert result.client == "test_client" - mock_settings_objects.remove.assert_called_once_with(test_cfg) - mock_settings_objects.append.assert_called_once() + # test_create_client_success: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_success + # test_create_client_already_exists: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_raises_on_existing + # test_get_client_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_success + # test_get_client_not_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_raises_on_not_found + # test_update_client: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateClient::test_update_client_success + pass ##################################################### @@ -113,39 +58,8 @@ def test_update_client(self, mock_get_settings, mock_settings_objects): class TestServerConfiguration: """Test server configuration operations""" - @pytest.mark.asyncio - @patch("server.api.utils.settings.get_mcp_prompts_with_overrides") - @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS") - @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS") - @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS") - async def 
test_get_server(self, mock_oci, mock_models, mock_databases, mock_get_prompts): - """Test getting server configuration""" - mock_databases.__iter__ = MagicMock( - return_value=iter([Database(name="test", user="u", password="p", dsn="d")]) - ) - mock_models.__iter__ = MagicMock(return_value=iter([Model(id="test", provider="openai", type="ll")])) - mock_oci.__iter__ = MagicMock(return_value=iter([OracleCloudSettings(auth_profile="DEFAULT")])) - mock_get_prompts.return_value = [] - - mock_mcp_engine = MagicMock() - result = await settings.get_server(mock_mcp_engine) - - assert "database_configs" in result - assert "model_configs" in result - assert "oci_configs" in result - assert "prompt_configs" in result - - @patch("server.api.utils.settings.bootstrap") - def test_update_server(self, mock_bootstrap): - """Test updating server configuration""" - mock_bootstrap.DATABASE_OBJECTS = [] - mock_bootstrap.MODEL_OBJECTS = [] - mock_bootstrap.OCI_OBJECTS = [] - - settings.update_server(make_sample_config_data()) - - assert hasattr(mock_bootstrap, "DATABASE_OBJECTS") - assert hasattr(mock_bootstrap, "MODEL_OBJECTS") + # test_get_server: See test/unit/server/api/utils/test_utils_settings.py::TestGetServer::test_get_server_returns_config + # test_update_server: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateServer::test_update_server_updates_databases @patch("server.api.utils.settings.bootstrap") def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap): @@ -185,64 +99,14 @@ def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap): class TestConfigLoading: """Test configuration loading operations""" - @patch("server.api.utils.settings.update_server") - @patch("server.api.utils.settings.update_client") - def test_load_config_from_json_data_with_client(self, mock_update_client, mock_update_server): - """Test loading config from JSON data with specific client""" - config_data = make_sample_config_data() - settings.load_config_from_json_data(config_data, client="test_client") - - mock_update_server.assert_called_once_with(config_data) - mock_update_client.assert_called_once() - - @patch("server.api.utils.settings.update_server") - @patch("server.api.utils.settings.update_client") - def test_load_config_from_json_data_without_client(self, mock_update_client, mock_update_server): - """Test loading config from JSON data without specific client""" - config_data = make_sample_config_data() - settings.load_config_from_json_data(config_data) - - mock_update_server.assert_called_once_with(config_data) - assert mock_update_client.call_count == 2 - - @patch("server.api.utils.settings.update_server") - def test_load_config_from_json_data_missing_client_settings(self, _mock_update_server): - """Test loading config from JSON data without client_settings""" - invalid_config = {"database_configs": [], "model_configs": [], "oci_configs": [], "prompt_configs": []} - - with pytest.raises(KeyError, match="Missing client_settings in config file"): - settings.load_config_from_json_data(invalid_config) - - @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.json"}) - @patch("os.path.isfile") - @patch("os.access") - @patch("builtins.open", mock_open(read_data='{"test": "data"}')) - @patch("json.load") - def test_read_config_from_json_file_success(self, mock_json_load, mock_access, mock_isfile): - """Test successful reading of config file""" - mock_isfile.return_value = True - mock_access.return_value = True - mock_json_load.return_value = make_sample_config_data() - - result = 
settings.read_config_from_json_file() - - assert isinstance(result, Configuration) - mock_json_load.assert_called_once() - - @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/nonexistent.json"}) - @patch("os.path.isfile") - def test_read_config_from_json_file_not_exists(self, mock_isfile): - """Test reading config file that doesn't exist""" - mock_isfile.return_value = False - - @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.txt"}) - def test_read_config_from_json_file_wrong_extension(self): - """Test reading config file with wrong extension""" - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(settings, "logger") - assert settings.logger.name == "api.core.settings" + # test_load_config_from_json_data_with_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_with_client + # test_load_config_from_json_data_without_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_without_client + # test_load_config_from_json_data_missing_client_settings: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_raises_missing_settings + # test_read_config_from_json_file_success: See test/unit/server/api/utils/test_utils_settings.py::TestReadConfigFromJsonFile::test_read_config_from_json_file_success + # test_read_config_from_json_file_not_exists: Empty test stub - not implemented + # test_read_config_from_json_file_wrong_extension: Empty test stub - not implemented + # test_logger_exists: See test/unit/server/api/utils/test_utils_settings.py::TestLoggerConfiguration::test_logger_exists + pass ##################################################### @@ -251,25 +115,8 @@ def test_logger_exists(self): class TestPromptOverrides: """Test prompt override operations""" - @patch("server.api.utils.settings.cache") - def test_load_prompt_override_with_text(self, mock_cache): - """Test loading prompt override when text is provided""" - prompt = {"name": "optimizer_test-prompt", "text": "You are a test assistant"} - - result = settings._load_prompt_override(prompt) - - assert result is True - mock_cache.set_override.assert_called_once_with("optimizer_test-prompt", "You are a test assistant") - - @patch("server.api.utils.settings.cache") - def test_load_prompt_override_without_text(self, mock_cache): - """Test loading prompt override when text is not provided""" - prompt = {"name": "optimizer_test-prompt"} - - result = settings._load_prompt_override(prompt) - - assert result is False - mock_cache.set_override.assert_not_called() + # test_load_prompt_override_with_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_with_text + # test_load_prompt_override_without_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_without_text @patch("server.api.utils.settings.cache") def test_load_prompt_override_empty_text(self, mock_cache): @@ -281,36 +128,6 @@ def test_load_prompt_override_empty_text(self, mock_cache): assert result is False mock_cache.set_override.assert_not_called() - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_success(self, mock_load_override): - """Test loading prompt configs successfully""" - mock_load_override.side_effect = [True, True, False] - config_data = { - "prompt_configs": [ - {"name": 
"prompt1", "text": "text1"}, - {"name": "prompt2", "text": "text2"}, - {"name": "prompt3", "text": "text3"}, - ] - } - - settings._load_prompt_configs(config_data) - - assert mock_load_override.call_count == 3 - - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_no_prompts_key(self, mock_load_override): - """Test loading prompt configs when key is missing""" - config_data = {"other_configs": []} - - settings._load_prompt_configs(config_data) - - mock_load_override.assert_not_called() - - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_empty_list(self, mock_load_override): - """Test loading prompt configs with empty list""" - config_data = {"prompt_configs": []} - - settings._load_prompt_configs(config_data) - - mock_load_override.assert_not_called() + # test_load_prompt_configs_success: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_with_prompts + # test_load_prompt_configs_no_prompts_key: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_without_key + # test_load_prompt_configs_empty_list: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_empty_list From 04bce4ae7c3f296b821a56580ade23478ba90fce Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 16:26:34 +0000 Subject: [PATCH 08/20] updated tests --- test/integration/server/bootstrap/__init__.py | 1 + test/integration/server/bootstrap/conftest.py | 206 +++++++++ .../bootstrap/test_bootstrap_configfile.py | 245 ++++++++++ .../bootstrap/test_bootstrap_databases.py | 205 +++++++++ .../server/bootstrap/test_bootstrap_models.py | 272 ++++++++++++ .../server/bootstrap/test_bootstrap_oci.py | 246 +++++++++++ .../bootstrap/test_bootstrap_settings.py | 170 +++++++ test/unit/server/bootstrap/__init__.py | 1 + test/unit/server/bootstrap/conftest.py | 286 ++++++++++++ .../bootstrap/test_bootstrap_bootstrap.py | 183 ++++++++ .../bootstrap/test_bootstrap_configfile.py | 229 ++++++++++ .../bootstrap/test_bootstrap_databases.py | 227 ++++++++++ .../server/bootstrap/test_bootstrap_models.py | 418 ++++++++++++++++++ .../server/bootstrap/test_bootstrap_oci.py | 329 ++++++++++++++ .../bootstrap/test_bootstrap_settings.py | 143 ++++++ tests/server/unit/bootstrap/test_bootstrap.py | 60 +-- 16 files changed, 3177 insertions(+), 44 deletions(-) create mode 100644 test/integration/server/bootstrap/__init__.py create mode 100644 test/integration/server/bootstrap/conftest.py create mode 100644 test/integration/server/bootstrap/test_bootstrap_configfile.py create mode 100644 test/integration/server/bootstrap/test_bootstrap_databases.py create mode 100644 test/integration/server/bootstrap/test_bootstrap_models.py create mode 100644 test/integration/server/bootstrap/test_bootstrap_oci.py create mode 100644 test/integration/server/bootstrap/test_bootstrap_settings.py create mode 100644 test/unit/server/bootstrap/__init__.py create mode 100644 test/unit/server/bootstrap/conftest.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_bootstrap.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_configfile.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_databases.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_models.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_oci.py create mode 100644 test/unit/server/bootstrap/test_bootstrap_settings.py diff --git 
a/test/integration/server/bootstrap/__init__.py b/test/integration/server/bootstrap/__init__.py new file mode 100644 index 00000000..90dc5216 --- /dev/null +++ b/test/integration/server/bootstrap/__init__.py @@ -0,0 +1 @@ +# Bootstrap integration test package diff --git a/test/integration/server/bootstrap/conftest.py b/test/integration/server/bootstrap/conftest.py new file mode 100644 index 00000000..4bea8405 --- /dev/null +++ b/test/integration/server/bootstrap/conftest.py @@ -0,0 +1,206 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Pytest fixtures for server/bootstrap integration tests. + +Integration tests for bootstrap test the actual bootstrap process with real +file I/O, environment variables, and configuration loading. These tests +verify end-to-end behavior of the bootstrap system. +""" + +# pylint: disable=redefined-outer-name protected-access + +import json +import os +import tempfile +from pathlib import Path + +import pytest + +from server.bootstrap.configfile import ConfigStore + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def make_config_file(temp_dir): + """Factory fixture to create real configuration JSON files.""" + + def _make_config_file( + filename: str = "configuration.json", + client_settings: dict = None, + database_configs: list = None, + model_configs: list = None, + oci_configs: list = None, + prompt_configs: list = None, + ): + config_data = { + "client_settings": client_settings or {"client": "test_client"}, + "database_configs": database_configs or [], + "model_configs": model_configs or [], + "oci_configs": oci_configs or [], + "prompt_configs": prompt_configs or [], + } + + file_path = temp_dir / filename + with open(file_path, "w", encoding="utf-8") as f: + json.dump(config_data, f, indent=2) + + return file_path + + return _make_config_file + + +@pytest.fixture +def make_oci_config_file(temp_dir): + """Factory fixture to create real OCI configuration files.""" + + def _make_oci_config_file( + filename: str = "config", + profiles: dict = None, + ): + """Create an OCI-style config file. + + Args: + filename: Name of the config file + profiles: Dict of profile_name -> dict of key-value pairs + e.g., {"DEFAULT": {"tenancy": "...", "region": "..."}} + """ + if profiles is None: + profiles = { + "DEFAULT": { + "tenancy": "ocid1.tenancy.oc1..testtenancy", + "region": "us-ashburn-1", + "fingerprint": "test:fingerprint", + } + } + + file_path = temp_dir / filename + with open(file_path, "w", encoding="utf-8") as f: + for profile_name, settings in profiles.items(): + f.write(f"[{profile_name}]\n") + for key, value in settings.items(): + f.write(f"{key}={value}\n") + f.write("\n") + + return file_path + + return _make_oci_config_file + + +@pytest.fixture +def clean_bootstrap_env(): + """Fixture to clean environment variables that affect bootstrap. + + This fixture saves current env vars, clears them for the test, + and restores them afterward. 
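+
+    Apply it with @pytest.mark.usefixtures("clean_bootstrap_env") on test
+    classes whose assertions depend on bootstrap-related environment variables.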
+ """ + env_vars = [ + # Database vars + "DB_USERNAME", + "DB_PASSWORD", + "DB_DSN", + "DB_WALLET_PASSWORD", + "TNS_ADMIN", + # Model API keys + "OPENAI_API_KEY", + "COHERE_API_KEY", + "PPLX_API_KEY", + # On-prem model URLs + "ON_PREM_OLLAMA_URL", + "ON_PREM_VLLM_URL", + "ON_PREM_HF_URL", + # OCI vars + "OCI_CLI_CONFIG_FILE", + "OCI_CLI_TENANCY", + "OCI_CLI_REGION", + "OCI_CLI_USER", + "OCI_CLI_FINGERPRINT", + "OCI_CLI_KEY_FILE", + "OCI_CLI_SECURITY_TOKEN_FILE", + "OCI_CLI_AUTH", + "OCI_GENAI_COMPARTMENT_ID", + "OCI_GENAI_REGION", + "OCI_GENAI_SERVICE_ENDPOINT", + ] + + original_values = {} + for var in env_vars: + original_values[var] = os.environ.pop(var, None) + + yield + + # Restore original values + for var, value in original_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] + + +@pytest.fixture +def reset_config_store(): + """Reset ConfigStore singleton state before and after each test.""" + # Reset before test + ConfigStore._config = None + + yield ConfigStore + + # Reset after test + ConfigStore._config = None + + +@pytest.fixture +def sample_database_config(): + """Sample database configuration dict.""" + return { + "name": "INTEGRATION_DB", + "user": "integration_user", + "password": "integration_pass", + "dsn": "localhost:1521/INTPDB", + } + + +@pytest.fixture +def sample_model_config(): + """Sample model configuration dict.""" + return { + "id": "integration-model", + "type": "ll", + "provider": "openai", + "enabled": True, + "api_key": "test-api-key", + "api_base": "https://api.openai.com/v1", + "max_tokens": 4096, + } + + +@pytest.fixture +def sample_oci_config(): + """Sample OCI configuration dict.""" + return { + "auth_profile": "INTEGRATION", + "tenancy": "ocid1.tenancy.oc1..integration", + "region": "us-phoenix-1", + "fingerprint": "integration:fingerprint", + } + + +@pytest.fixture +def sample_settings_config(): + """Sample settings configuration dict.""" + return { + "client": "integration_client", + "ll_model": { + "model": "gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 4096, + "chat_history": True, + }, + } diff --git a/test/integration/server/bootstrap/test_bootstrap_configfile.py b/test/integration/server/bootstrap/test_bootstrap_configfile.py new file mode 100644 index 00000000..48cc2943 --- /dev/null +++ b/test/integration/server/bootstrap/test_bootstrap_configfile.py @@ -0,0 +1,245 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/bootstrap/configfile.py + +Tests the ConfigStore class with real file I/O operations. 
+""" + +# pylint: disable=redefined-outer-name + +import json +import os +from pathlib import Path + +import pytest + +from server.bootstrap.configfile import config_file_path + + +class TestConfigStoreFileOperations: + """Integration tests for ConfigStore with real file operations.""" + + def test_load_valid_json_file(self, reset_config_store, make_config_file, sample_settings_config): + """ConfigStore should load a valid JSON configuration file.""" + config_path = make_config_file( + client_settings=sample_settings_config, + ) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert config.client_settings.client == "integration_client" + + def test_load_file_with_all_sections( + self, + reset_config_store, + make_config_file, + sample_settings_config, + sample_database_config, + sample_model_config, + sample_oci_config, + ): + """ConfigStore should load file with all configuration sections.""" + config_path = make_config_file( + client_settings=sample_settings_config, + database_configs=[sample_database_config], + model_configs=[sample_model_config], + oci_configs=[sample_oci_config], + ) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert len(config.database_configs) == 1 + assert config.database_configs[0].name == "INTEGRATION_DB" + assert len(config.model_configs) == 1 + assert config.model_configs[0].id == "integration-model" + assert len(config.oci_configs) == 1 + assert config.oci_configs[0].auth_profile == "INTEGRATION" + + def test_load_nonexistent_file_returns_none(self, reset_config_store, temp_dir): + """ConfigStore should handle nonexistent files gracefully.""" + nonexistent_path = temp_dir / "does_not_exist.json" + + reset_config_store.load_from_file(nonexistent_path) + config = reset_config_store.get() + + assert config is None + + def test_load_file_with_unicode_content(self, reset_config_store, temp_dir): + """ConfigStore should handle files with unicode content.""" + config_data = { + "client_settings": {"client": "unicode_test_客户端"}, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + config_path = temp_dir / "unicode_config.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f, ensure_ascii=False) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert config.client_settings.client == "unicode_test_客户端" + + def test_load_file_with_nested_settings(self, reset_config_store, temp_dir): + """ConfigStore should handle deeply nested settings.""" + config_data = { + "client_settings": { + "client": "nested_test", + "ll_model": { + "model": "gpt-4o-mini", + "temperature": 0.5, + "max_tokens": 2048, + "chat_history": True, + }, + "vector_search": { + "discovery": True, + "rephrase": True, + "grade": True, + "top_k": 5, + }, + }, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + config_path = temp_dir / "nested_config.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert config.client_settings.ll_model.temperature == 0.5 + assert config.client_settings.vector_search.top_k == 5 + + def test_load_large_config_file(self, reset_config_store, temp_dir): + """ConfigStore should handle 
large configuration files.""" + # Create config with many database entries + database_configs = [ + { + "name": f"DB_{i}", + "user": f"user_{i}", + "password": f"pass_{i}", + "dsn": f"host{i}:1521/PDB{i}", + } + for i in range(50) + ] + + config_data = { + "client_settings": {"client": "large_test"}, + "database_configs": database_configs, + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + config_path = temp_dir / "large_config.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert len(config.database_configs) == 50 + + def test_load_file_preserves_field_types(self, reset_config_store, temp_dir): + """ConfigStore should preserve correct field types after loading.""" + config_data = { + "client_settings": { + "client": "type_test", + "ll_model": { + "model": "test-model", + "temperature": 0.7, # float + "max_tokens": 4096, # int + "chat_history": True, # bool + }, + }, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + config_path = temp_dir / "types_config.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f) + + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert isinstance(config.client_settings.ll_model.temperature, float) + assert isinstance(config.client_settings.ll_model.max_tokens, int) + assert isinstance(config.client_settings.ll_model.chat_history, bool) + + +class TestConfigStoreValidation: + """Integration tests for ConfigStore validation with real files.""" + + def test_load_file_validates_required_fields(self, reset_config_store, temp_dir): + """ConfigStore should validate required fields in config.""" + # Missing required 'client' field in client_settings + config_data = { + "client_settings": {}, # Missing 'client' + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + } + + config_path = temp_dir / "invalid_config.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f) + + with pytest.raises(Exception): # Pydantic ValidationError + reset_config_store.load_from_file(config_path) + + def test_load_malformed_json_raises_error(self, reset_config_store, temp_dir): + """ConfigStore should raise error for malformed JSON.""" + config_path = temp_dir / "malformed.json" + with open(config_path, "w", encoding="utf-8") as f: + f.write("{ invalid json content }") + + with pytest.raises(json.JSONDecodeError): + reset_config_store.load_from_file(config_path) + + +class TestConfigFilePath: + """Integration tests for config_file_path function.""" + + def test_config_file_path_returns_valid_path(self): + """config_file_path should return a valid filesystem path.""" + path = config_file_path() + + assert path is not None + assert isinstance(path, str) + assert path.endswith("configuration.json") + + def test_config_file_path_parent_directory_structure(self): + """config_file_path should point to server/etc directory.""" + path = config_file_path() + path_obj = Path(path) + + # Parent should be 'etc' directory + assert path_obj.parent.name == "etc" + # Grandparent should be 'server' directory + assert path_obj.parent.parent.name == "server" + + def test_config_file_path_is_absolute(self): + """config_file_path should return an absolute path.""" + path = config_file_path() + + assert os.path.isabs(path) diff --git 
a/test/integration/server/bootstrap/test_bootstrap_databases.py b/test/integration/server/bootstrap/test_bootstrap_databases.py new file mode 100644 index 00000000..f46065b3 --- /dev/null +++ b/test/integration/server/bootstrap/test_bootstrap_databases.py @@ -0,0 +1,205 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/bootstrap/databases.py + +Tests the database bootstrap process with real configuration files +and environment variables. +""" + +# pylint: disable=redefined-outer-name + +import os + +import pytest + +from server.bootstrap import databases as databases_module +from common.schema import Database + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestDatabasesBootstrapWithConfig: + """Integration tests for database bootstrap with configuration files.""" + + def test_bootstrap_returns_database_objects(self): + """databases.main() should return list of Database objects.""" + result = databases_module.main() + + assert isinstance(result, list) + assert all(isinstance(db, Database) for db in result) + + def test_bootstrap_creates_default_database(self): + """databases.main() should always create DEFAULT database.""" + result = databases_module.main() + + db_names = [db.name for db in result] + assert "DEFAULT" in db_names + + def test_bootstrap_with_config_file_databases(self, reset_config_store, make_config_file): + """databases.main() should load databases from config file.""" + config_path = make_config_file( + database_configs=[ + { + "name": "CONFIG_DB1", + "user": "config_user1", + "password": "config_pass1", + "dsn": "host1:1521/PDB1", + }, + { + "name": "CONFIG_DB2", + "user": "config_user2", + "password": "config_pass2", + "dsn": "host2:1521/PDB2", + }, + ], + ) + + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + db_names = [db.name for db in result] + assert "CONFIG_DB1" in db_names + assert "CONFIG_DB2" in db_names + + def test_bootstrap_default_from_config_overridden_by_env(self, reset_config_store, make_config_file): + """databases.main() should override DEFAULT config values with env vars.""" + config_path = make_config_file( + database_configs=[ + { + "name": "DEFAULT", + "user": "config_user", + "password": "config_pass", + "dsn": "config_host:1521/CFGPDB", + }, + ], + ) + + os.environ["DB_USERNAME"] = "env_user" + os.environ["DB_PASSWORD"] = "env_password" + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "env_user" + assert default_db.password == "env_password" + # DSN not in env, should keep config value + assert default_db.dsn == "config_host:1521/CFGPDB" + finally: + del os.environ["DB_USERNAME"] + del os.environ["DB_PASSWORD"] + + def test_bootstrap_raises_on_duplicate_names(self, reset_config_store, make_config_file): + """databases.main() should raise error for duplicate database names.""" + config_path = make_config_file( + database_configs=[ + {"name": "DUP_DB", "user": "user1", "password": "pass1", "dsn": "dsn1"}, + {"name": "dup_db", "user": "user2", "password": "pass2", "dsn": "dsn2"}, + ], + ) + + reset_config_store.load_from_file(config_path) + + with pytest.raises(ValueError, match="Duplicate database name"): + databases_module.main() + + +@pytest.mark.usefixtures("reset_config_store", 
"clean_bootstrap_env") +class TestDatabasesBootstrapWithEnvVars: + """Integration tests for database bootstrap with environment variables.""" + + def test_bootstrap_uses_env_vars_for_default(self): + """databases.main() should use env vars for DEFAULT when no config.""" + os.environ["DB_USERNAME"] = "env_user" + os.environ["DB_PASSWORD"] = "env_password" + os.environ["DB_DSN"] = "env_host:1521/ENVPDB" + + try: + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "env_user" + assert default_db.password == "env_password" + assert default_db.dsn == "env_host:1521/ENVPDB" + finally: + del os.environ["DB_USERNAME"] + del os.environ["DB_PASSWORD"] + del os.environ["DB_DSN"] + + def test_bootstrap_wallet_password_sets_wallet_location(self): + """databases.main() should set wallet_location when wallet_password present.""" + os.environ["DB_WALLET_PASSWORD"] = "wallet_secret" + os.environ["TNS_ADMIN"] = "/path/to/wallet" + + try: + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.wallet_password == "wallet_secret" + assert default_db.wallet_location == "/path/to/wallet" + assert default_db.config_dir == "/path/to/wallet" + finally: + del os.environ["DB_WALLET_PASSWORD"] + del os.environ["TNS_ADMIN"] + + def test_bootstrap_tns_admin_default(self): + """databases.main() should use 'tns_admin' as default config_dir.""" + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.config_dir == "tns_admin" + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestDatabasesBootstrapPreservation: + """Integration tests for database bootstrap preserving non-DEFAULT databases.""" + + def test_bootstrap_preserves_non_default_databases(self, reset_config_store, make_config_file): + """databases.main() should not modify non-DEFAULT databases.""" + os.environ["DB_USERNAME"] = "should_not_apply" + + config_path = make_config_file( + database_configs=[ + { + "name": "CUSTOM_DB", + "user": "custom_user", + "password": "custom_pass", + "dsn": "custom:1521/CPDB", + }, + ], + ) + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + custom_db = next(db for db in result if db.name == "CUSTOM_DB") + assert custom_db.user == "custom_user" + assert custom_db.password == "custom_pass" + finally: + del os.environ["DB_USERNAME"] + + def test_bootstrap_creates_default_when_not_in_config(self, reset_config_store, make_config_file): + """databases.main() should create DEFAULT from env when not in config.""" + os.environ["DB_USERNAME"] = "env_default_user" + + config_path = make_config_file( + database_configs=[ + {"name": "OTHER_DB", "user": "other", "password": "other", "dsn": "other"}, + ], + ) + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + db_names = [db.name for db in result] + assert "DEFAULT" in db_names + assert "OTHER_DB" in db_names + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "env_default_user" + finally: + del os.environ["DB_USERNAME"] diff --git a/test/integration/server/bootstrap/test_bootstrap_models.py b/test/integration/server/bootstrap/test_bootstrap_models.py new file mode 100644 index 00000000..52eb5ccc --- /dev/null +++ b/test/integration/server/bootstrap/test_bootstrap_models.py @@ -0,0 +1,272 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/bootstrap/models.py + +Tests the models bootstrap process with real configuration files +and environment variables. +""" + +# pylint: disable=redefined-outer-name + +import os +from unittest.mock import patch + +import pytest + +from server.bootstrap import models as models_module +from common.schema import Model + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestModelsBootstrapBasic: + """Integration tests for basic models bootstrap functionality.""" + + def test_bootstrap_returns_model_objects(self): + """models.main() should return list of Model objects.""" + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + assert isinstance(result, list) + assert all(isinstance(m, Model) for m in result) + + def test_bootstrap_includes_base_models(self): + """models.main() should include base model configurations.""" + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + model_ids = [m.id for m in result] + # Check for some expected base models + assert "gpt-4o-mini" in model_ids + assert "command-r" in model_ids + + def test_bootstrap_includes_ll_and_embed_models(self): + """models.main() should include both LLM and embedding models.""" + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + model_types = {m.type for m in result} + assert "ll" in model_types + assert "embed" in model_types + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestModelsBootstrapWithApiKeys: + """Integration tests for models bootstrap with API keys.""" + + def test_bootstrap_enables_models_with_openai_key(self): + """models.main() should enable OpenAI models when key is present.""" + os.environ["OPENAI_API_KEY"] = "test-openai-key" + + try: + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.enabled is True + assert openai_model.api_key == "test-openai-key" + finally: + del os.environ["OPENAI_API_KEY"] + + def test_bootstrap_enables_models_with_cohere_key(self): + """models.main() should enable Cohere models when key is present.""" + os.environ["COHERE_API_KEY"] = "test-cohere-key" + + try: + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + cohere_model = next(m for m in result if m.id == "command-r") + assert cohere_model.enabled is True + assert cohere_model.api_key == "test-cohere-key" + finally: + del os.environ["COHERE_API_KEY"] + + def test_bootstrap_disables_models_without_keys(self): + """models.main() should disable models when API keys are not present.""" + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + # Without OPENAI_API_KEY, the model should be disabled + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.enabled is False + + 
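+# A rough sketch of the key gating exercised above (an assumed shape, not the
+# actual bootstrap internals): a provider's model is enabled only when its
+# API key is present in the environment, e.g.
+#
+#     api_key = os.environ.get("OPENAI_API_KEY", "")
+#     model.api_key = api_key
+#     model.enabled = bool(api_key)
+
+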
+@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestModelsBootstrapWithOnPremUrls: + """Integration tests for models bootstrap with on-prem URLs.""" + + def test_bootstrap_enables_ollama_with_url(self): + """models.main() should enable Ollama models when URL is set.""" + os.environ["ON_PREM_OLLAMA_URL"] = "http://localhost:11434" + + try: + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + ollama_model = next(m for m in result if m.id == "llama3.1") + assert ollama_model.enabled is True + assert ollama_model.api_base == "http://localhost:11434" + finally: + del os.environ["ON_PREM_OLLAMA_URL"] + + def test_bootstrap_checks_url_accessibility(self): + """models.main() should check URL accessibility for enabled models.""" + os.environ["ON_PREM_OLLAMA_URL"] = "http://localhost:11434" + + try: + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (False, "Connection refused") + result = models_module.main() + + ollama_model = next(m for m in result if m.id == "llama3.1") + # Should be disabled if URL is not accessible + assert ollama_model.enabled is False + finally: + del os.environ["ON_PREM_OLLAMA_URL"] + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestModelsBootstrapWithConfigStore: + """Integration tests for models bootstrap with ConfigStore configuration.""" + + def test_bootstrap_merges_config_store_models(self, reset_config_store, make_config_file): + """models.main() should merge models from ConfigStore.""" + config_path = make_config_file( + model_configs=[ + { + "id": "custom-model", + "type": "ll", + "provider": "custom", + "enabled": True, + "api_base": "https://custom.api/v1", + "api_key": "custom-key", + }, + ], + ) + + try: + reset_config_store.load_from_file(config_path) + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + model_ids = [m.id for m in result] + assert "custom-model" in model_ids + + custom_model = next(m for m in result if m.id == "custom-model") + assert custom_model.provider == "custom" + assert custom_model.api_base == "https://custom.api/v1" + finally: + pass + + def test_bootstrap_config_store_overrides_base_model(self, reset_config_store, make_config_file): + """models.main() should let ConfigStore override base model settings.""" + config_path = make_config_file( + model_configs=[ + { + "id": "gpt-4o-mini", + "type": "ll", + "provider": "openai", + "enabled": True, + "api_base": "https://api.openai.com/v1", + "api_key": "override-key", + "max_tokens": 9999, + }, + ], + ) + + try: + reset_config_store.load_from_file(config_path) + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.api_key == "override-key" + assert openai_model.max_tokens == 9999 + finally: + pass + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestModelsBootstrapDuplicateDetection: + """Integration tests for models bootstrap duplicate detection.""" + + def test_bootstrap_deduplicates_config_store_models(self, reset_config_store, make_config_file): + """models.main() should deduplicate models with same provider+id in ConfigStore. 
+ + Note: ConfigStore models with the same (provider, id) key are deduplicated + during the merge process (dict keyed by tuple keeps last value). + This is different from base model duplicate detection which raises an error. + """ + # Create config with duplicate model (same provider + id) + config_path = make_config_file( + model_configs=[ + { + "id": "duplicate-model", + "type": "ll", + "provider": "test", + "api_base": "http://test1", + }, + { + "id": "duplicate-model", + "type": "ll", + "provider": "test", + "api_base": "http://test2", + }, + ], + ) + + reset_config_store.load_from_file(config_path) + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + # Should have only one model with the duplicate id (last one wins) + dup_models = [m for m in result if m.id == "duplicate-model"] + assert len(dup_models) == 1 + # The last entry in the config should win + assert dup_models[0].api_base == "http://test2" + + def test_bootstrap_allows_same_id_different_provider(self, reset_config_store, make_config_file): + """models.main() should allow same ID with different providers.""" + config_path = make_config_file( + model_configs=[ + { + "id": "shared-model-name", + "type": "ll", + "provider": "provider1", + "api_base": "http://provider1", + }, + { + "id": "shared-model-name", + "type": "ll", + "provider": "provider2", + "api_base": "http://provider2", + }, + ], + ) + + reset_config_store.load_from_file(config_path) + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + result = models_module.main() + + # Both should be present + shared_models = [m for m in result if m.id == "shared-model-name"] + assert len(shared_models) == 2 + providers = {m.provider for m in shared_models} + assert providers == {"provider1", "provider2"} diff --git a/test/integration/server/bootstrap/test_bootstrap_oci.py b/test/integration/server/bootstrap/test_bootstrap_oci.py new file mode 100644 index 00000000..4d4afd47 --- /dev/null +++ b/test/integration/server/bootstrap/test_bootstrap_oci.py @@ -0,0 +1,246 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/bootstrap/oci.py + +Tests the OCI bootstrap process with real configuration files +and environment variables. 
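+
+Covers DEFAULT profiles driven by OCI_CLI_* environment variables, profiles
+read from a real OCI config file, and ConfigStore entries taking precedence
+over file values.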
+""" + +# pylint: disable=redefined-outer-name + +import os + +import oci +import pytest + +from server.bootstrap import oci as oci_module +from common.schema import OracleCloudSettings + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestOciBootstrapWithEnvVars: + """Integration tests for OCI bootstrap with environment variables.""" + + def test_bootstrap_returns_oci_settings_objects(self): + """oci.main() should return list of OracleCloudSettings objects.""" + # Point to nonexistent OCI config to test env var path + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + + try: + result = oci_module.main() + + assert isinstance(result, list) + assert all(isinstance(s, OracleCloudSettings) for s in result) + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + def test_bootstrap_creates_default_profile(self): + """oci.main() should always create DEFAULT profile.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + + try: + result = oci_module.main() + + profile_names = [s.auth_profile for s in result] + assert oci.config.DEFAULT_PROFILE in profile_names + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + def test_bootstrap_applies_tenancy_env_var(self): + """oci.main() should apply OCI_CLI_TENANCY to DEFAULT profile.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + os.environ["OCI_CLI_TENANCY"] = "ocid1.tenancy.oc1..envtenancy" + + try: + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.tenancy == "ocid1.tenancy.oc1..envtenancy" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + del os.environ["OCI_CLI_TENANCY"] + + def test_bootstrap_applies_region_env_var(self): + """oci.main() should apply OCI_CLI_REGION to DEFAULT profile.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + os.environ["OCI_CLI_REGION"] = "us-chicago-1" + + try: + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.region == "us-chicago-1" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + del os.environ["OCI_CLI_REGION"] + + def test_bootstrap_applies_genai_env_vars(self): + """oci.main() should apply GenAI environment variables.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + os.environ["OCI_GENAI_COMPARTMENT_ID"] = "ocid1.compartment.oc1..genaicomp" + os.environ["OCI_GENAI_REGION"] = "us-chicago-1" + + try: + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.genai_compartment_id == "ocid1.compartment.oc1..genaicomp" + assert default_profile.genai_region == "us-chicago-1" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + del os.environ["OCI_GENAI_COMPARTMENT_ID"] + del os.environ["OCI_GENAI_REGION"] + + def test_bootstrap_explicit_auth_method(self): + """oci.main() should use OCI_CLI_AUTH when specified.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + os.environ["OCI_CLI_AUTH"] = "instance_principal" + + try: + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.authentication == "instance_principal" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + del os.environ["OCI_CLI_AUTH"] + + def test_bootstrap_default_auth_is_api_key(self): + """oci.main() should default to api_key authentication.""" + 
os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + + try: + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.authentication == "api_key" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestOciBootstrapWithConfigFile: + """Integration tests for OCI bootstrap with real OCI config files.""" + + def test_bootstrap_reads_oci_config_file(self, make_oci_config_file): + """oci.main() should read profiles from OCI config file.""" + config_path = make_oci_config_file( + profiles={ + "DEFAULT": { + "tenancy": "ocid1.tenancy.oc1..filetenancy", + "region": "us-ashburn-1", + "fingerprint": "file:fingerprint", + }, + } + ) + + os.environ["OCI_CLI_CONFIG_FILE"] = str(config_path) + + try: + result = oci_module.main() + + # Should have loaded the profile from file + profile_names = [s.auth_profile for s in result] + assert "DEFAULT" in profile_names + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + def test_bootstrap_loads_multiple_profiles(self, make_oci_config_file): + """oci.main() should load multiple profiles from OCI config file.""" + config_path = make_oci_config_file( + profiles={ + "DEFAULT": { + "tenancy": "ocid1.tenancy.oc1..default", + "region": "us-ashburn-1", + "fingerprint": "default:fp", + }, + "PRODUCTION": { + "tenancy": "ocid1.tenancy.oc1..production", + "region": "us-phoenix-1", + "fingerprint": "prod:fp", + }, + } + ) + + os.environ["OCI_CLI_CONFIG_FILE"] = str(config_path) + + try: + result = oci_module.main() + + profile_names = [s.auth_profile for s in result] + assert "DEFAULT" in profile_names + assert "PRODUCTION" in profile_names + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestOciBootstrapWithConfigStore: + """Integration tests for OCI bootstrap with ConfigStore configuration.""" + + def test_bootstrap_merges_config_store_profiles(self, reset_config_store, make_config_file): + """oci.main() should merge profiles from ConfigStore.""" + os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" + + config_path = make_config_file( + oci_configs=[ + { + "auth_profile": "CONFIGSTORE_PROFILE", + "tenancy": "ocid1.tenancy.oc1..configstore", + "region": "us-sanjose-1", + "fingerprint": "cs:fingerprint", + }, + ], + ) + + try: + reset_config_store.load_from_file(config_path) + result = oci_module.main() + + profile_names = [s.auth_profile for s in result] + assert "CONFIGSTORE_PROFILE" in profile_names + + cs_profile = next(p for p in result if p.auth_profile == "CONFIGSTORE_PROFILE") + assert cs_profile.tenancy == "ocid1.tenancy.oc1..configstore" + assert cs_profile.region == "us-sanjose-1" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + def test_bootstrap_config_store_overrides_file_profile( + self, reset_config_store, make_config_file, make_oci_config_file + ): + """oci.main() should let ConfigStore override file profiles.""" + oci_config_path = make_oci_config_file( + profiles={ + "DEFAULT": { + "tenancy": "ocid1.tenancy.oc1..fromfile", + "region": "us-ashburn-1", + "fingerprint": "file:fp", + }, + } + ) + + config_path = make_config_file( + oci_configs=[ + { + "auth_profile": "DEFAULT", + "tenancy": "ocid1.tenancy.oc1..fromconfigstore", + "region": "us-phoenix-1", + "fingerprint": "cs:fp", + }, + ], + ) + + os.environ["OCI_CLI_CONFIG_FILE"] = str(oci_config_path) + + try: + 
reset_config_store.load_from_file(config_path) + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + # ConfigStore should override file values + assert default_profile.tenancy == "ocid1.tenancy.oc1..fromconfigstore" + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] diff --git a/test/integration/server/bootstrap/test_bootstrap_settings.py b/test/integration/server/bootstrap/test_bootstrap_settings.py new file mode 100644 index 00000000..1d71376c --- /dev/null +++ b/test/integration/server/bootstrap/test_bootstrap_settings.py @@ -0,0 +1,170 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/bootstrap/settings.py + +Tests the settings bootstrap process with real configuration files. +""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import pytest + +from server.bootstrap import settings as settings_module +from common.schema import Settings + + +@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") +class TestSettingsBootstrapWithConfig: + """Integration tests for settings bootstrap with configuration files.""" + + def test_bootstrap_creates_default_and_server_clients(self): + """settings.main() should always create default and server clients.""" + result = settings_module.main() + + assert len(result) == 2 + client_names = [s.client for s in result] + assert "default" in client_names + assert "server" in client_names + + def test_bootstrap_returns_settings_objects(self): + """settings.main() should return list of Settings objects.""" + result = settings_module.main() + + assert all(isinstance(s, Settings) for s in result) + + def test_bootstrap_with_config_file(self, reset_config_store, make_config_file): + """settings.main() should use settings from config file.""" + config_path = make_config_file( + client_settings={ + "client": "config_client", + "ll_model": { + "model": "custom-model", + "temperature": 0.9, + "max_tokens": 8192, + "chat_history": False, + }, + }, + ) + + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + # All clients should inherit config file settings + for s in result: + assert s.ll_model.model == "custom-model" + assert s.ll_model.temperature == 0.9 + assert s.ll_model.max_tokens == 8192 + assert s.ll_model.chat_history is False + + def test_bootstrap_overrides_client_names(self, reset_config_store, make_config_file): + """settings.main() should override client field to default/server.""" + config_path = make_config_file( + client_settings={ + "client": "original_client_name", + }, + ) + + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + client_names = [s.client for s in result] + assert "original_client_name" not in client_names + assert "default" in client_names + assert "server" in client_names + + def test_bootstrap_with_vector_search_settings(self, reset_config_store, make_config_file): + """settings.main() should load vector search settings from config.""" + config_path = make_config_file( + client_settings={ + "client": "vs_client", + "vector_search": { + "discovery": False, + "rephrase": False, + "grade": True, + "top_k": 10, + "search_type": "Similarity", + }, + }, + ) + + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + for s in result: + assert s.vector_search.discovery is False + 
assert s.vector_search.rephrase is False + assert s.vector_search.grade is True + assert s.vector_search.top_k == 10 + + def test_bootstrap_with_oci_settings(self, reset_config_store, make_config_file): + """settings.main() should load OCI settings from config.""" + config_path = make_config_file( + client_settings={ + "client": "oci_client", + "oci": { + "auth_profile": "CUSTOM_PROFILE", + }, + }, + ) + + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + for s in result: + assert s.oci.auth_profile == "CUSTOM_PROFILE" + + def test_bootstrap_with_database_settings(self, reset_config_store, make_config_file): + """settings.main() should load database settings from config.""" + config_path = make_config_file( + client_settings={ + "client": "db_client", + "database": { + "alias": "CUSTOM_DB", + }, + }, + ) + + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + for s in result: + assert s.database.alias == "CUSTOM_DB" + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestSettingsBootstrapWithoutConfig: + """Integration tests for settings bootstrap without configuration.""" + + def test_bootstrap_without_config_uses_defaults(self, reset_config_store): + """settings.main() should use default values without config file.""" + # Ensure no config is loaded + assert reset_config_store.get() is None + + result = settings_module.main() + + assert len(result) == 2 + # Should have default Settings values + for s in result: + assert isinstance(s, Settings) + # Default values from Settings model + assert s.oci.auth_profile == "DEFAULT" + assert s.database.alias == "DEFAULT" + + +@pytest.mark.usefixtures("clean_bootstrap_env") +class TestSettingsBootstrapIdempotency: + """Integration tests for settings bootstrap idempotency.""" + + def test_bootstrap_produces_consistent_results(self, reset_config_store): + """settings.main() should produce consistent results on multiple calls.""" + result1 = settings_module.main() + + # Reset and call again + reset_config_store._config = None + result2 = settings_module.main() + + assert len(result1) == len(result2) + for s1, s2 in zip(result1, result2): + assert s1.client == s2.client diff --git a/test/unit/server/bootstrap/__init__.py b/test/unit/server/bootstrap/__init__.py new file mode 100644 index 00000000..170366b5 --- /dev/null +++ b/test/unit/server/bootstrap/__init__.py @@ -0,0 +1 @@ +# Bootstrap unit test package diff --git a/test/unit/server/bootstrap/conftest.py b/test/unit/server/bootstrap/conftest.py new file mode 100644 index 00000000..aff230bb --- /dev/null +++ b/test/unit/server/bootstrap/conftest.py @@ -0,0 +1,286 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Pytest fixtures for server/bootstrap unit tests. 
+""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import json +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from common.schema import ( + Configuration, + Database, + Model, + OracleCloudSettings, + Settings, + LargeLanguageSettings, +) +from server.bootstrap.configfile import ConfigStore + + +@pytest.fixture +def make_database(): + """Factory fixture to create Database objects.""" + + def _make_database( + name: str = "TEST_DB", + user: str = "test_user", + password: str = "test_password", + dsn: str = "localhost:1521/TESTPDB", + wallet_password: str = None, + **kwargs, + ) -> Database: + return Database( + name=name, + user=user, + password=password, + dsn=dsn, + wallet_password=wallet_password, + **kwargs, + ) + + return _make_database + + +@pytest.fixture +def make_model(): + """Factory fixture to create Model objects.""" + + def _make_model( + model_id: str = "gpt-4o-mini", + model_type: str = "ll", + provider: str = "openai", + enabled: bool = True, + api_key: str = "test-key", + api_base: str = "https://api.openai.com/v1", + **kwargs, + ) -> Model: + return Model( + id=model_id, + type=model_type, + provider=provider, + enabled=enabled, + api_key=api_key, + api_base=api_base, + **kwargs, + ) + + return _make_model + + +@pytest.fixture +def make_oci_config(): + """Factory fixture to create OracleCloudSettings objects. + + Note: The 'user' field requires OCID format pattern matching. + Use None to skip the user field in tests that don't need it. + """ + + def _make_oci_config( + auth_profile: str = "DEFAULT", + tenancy: str = "test-tenancy", + region: str = "us-ashburn-1", + user: str = None, # Use None by default - OCID pattern required + fingerprint: str = "test-fingerprint", + key_file: str = "/path/to/key", + **kwargs, + ) -> OracleCloudSettings: + return OracleCloudSettings( + auth_profile=auth_profile, + tenancy=tenancy, + region=region, + user=user, + fingerprint=fingerprint, + key_file=key_file, + **kwargs, + ) + + return _make_oci_config + + +@pytest.fixture +def make_ll_settings(): + """Factory fixture to create LargeLanguageSettings objects.""" + + def _make_ll_settings( + model: str = "gpt-4o-mini", + temperature: float = 0.7, + max_tokens: int = 4096, + chat_history: bool = True, + **kwargs, + ) -> LargeLanguageSettings: + return LargeLanguageSettings( + model=model, + temperature=temperature, + max_tokens=max_tokens, + chat_history=chat_history, + **kwargs, + ) + + return _make_ll_settings + + +@pytest.fixture +def make_settings(make_ll_settings): + """Factory fixture to create Settings objects.""" + + def _make_settings( + client: str = "test_client", + ll_model: LargeLanguageSettings = None, + **kwargs, + ) -> Settings: + if ll_model is None: + ll_model = make_ll_settings() + return Settings( + client=client, + ll_model=ll_model, + **kwargs, + ) + + return _make_settings + + +@pytest.fixture +def make_configuration(make_settings): + """Factory fixture to create Configuration objects.""" + + def _make_configuration( + client_settings: Settings = None, + database_configs: list = None, + model_configs: list = None, + oci_configs: list = None, + **kwargs, + ) -> Configuration: + return Configuration( + client_settings=client_settings or make_settings(), + database_configs=database_configs or [], + model_configs=model_configs or [], + oci_configs=oci_configs or [], + prompt_configs=[], + **kwargs, + ) + + return _make_configuration + + +@pytest.fixture +def 
temp_config_file(make_settings): + """Create a temporary configuration JSON file.""" + + def _create_temp_config( + client_settings: Settings = None, + database_configs: list = None, + model_configs: list = None, + oci_configs: list = None, + ): + config_data = { + "client_settings": (client_settings or make_settings()).model_dump(), + "database_configs": [ + (db if isinstance(db, dict) else db.model_dump()) + for db in (database_configs or []) + ], + "model_configs": [ + (m if isinstance(m, dict) else m.model_dump()) + for m in (model_configs or []) + ], + "oci_configs": [ + (o if isinstance(o, dict) else o.model_dump()) + for o in (oci_configs or []) + ], + "prompt_configs": [], + } + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False, encoding="utf-8" + ) as temp_file: + json.dump(config_data, temp_file) + return Path(temp_file.name) + + return _create_temp_config + + +@pytest.fixture +def reset_config_store(): + """Reset ConfigStore singleton state before and after each test.""" + # Reset before test + ConfigStore._config = None + + yield ConfigStore + + # Reset after test + ConfigStore._config = None + + +@pytest.fixture +def mock_oci_config_parser(): + """Mock OCI config parser for testing OCI bootstrap.""" + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = [] + mock_parser.return_value = mock_instance + yield mock_parser + + +@pytest.fixture +def mock_oci_config_from_file(): + """Mock oci.config.from_file for testing OCI bootstrap.""" + with patch("oci.config.from_file") as mock_from_file: + yield mock_from_file + + +@pytest.fixture +def mock_is_url_accessible(): + """Mock is_url_accessible for testing model bootstrap.""" + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + yield mock_accessible + + +@pytest.fixture +def clean_env(): + """Fixture to temporarily clear relevant environment variables.""" + env_vars = [ + "DB_USERNAME", + "DB_PASSWORD", + "DB_DSN", + "DB_WALLET_PASSWORD", + "TNS_ADMIN", + "OPENAI_API_KEY", + "COHERE_API_KEY", + "PPLX_API_KEY", + "ON_PREM_OLLAMA_URL", + "ON_PREM_VLLM_URL", + "ON_PREM_HF_URL", + "OCI_CLI_CONFIG_FILE", + "OCI_CLI_TENANCY", + "OCI_CLI_REGION", + "OCI_CLI_USER", + "OCI_CLI_FINGERPRINT", + "OCI_CLI_KEY_FILE", + "OCI_CLI_SECURITY_TOKEN_FILE", + "OCI_CLI_AUTH", + "OCI_GENAI_COMPARTMENT_ID", + "OCI_GENAI_REGION", + "OCI_GENAI_SERVICE_ENDPOINT", + ] + + original_values = {} + for var in env_vars: + original_values[var] = os.environ.pop(var, None) + + yield + + # Restore original values + for var, value in original_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] diff --git a/test/unit/server/bootstrap/test_bootstrap_bootstrap.py b/test/unit/server/bootstrap/test_bootstrap_bootstrap.py new file mode 100644 index 00000000..542baee1 --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_bootstrap.py @@ -0,0 +1,183 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/bootstrap.py +Tests for the main bootstrap module that coordinates all bootstrap operations. 
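+
+The module under test runs its bootstrap at import time, so these tests
+reload it via importlib.reload() while the sub-module main() functions are
+patched. The expected module-level shape, inferred from the assertions in
+this file, is roughly:
+
+    DATABASE_OBJECTS = databases.main()
+    MODEL_OBJECTS = models.main()
+    OCI_OBJECTS = oci.main()
+    SETTINGS_OBJECTS = settings.main()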
+""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods +# pylint: disable=import-outside-toplevel + +import importlib +from unittest.mock import patch + +from server.bootstrap import bootstrap + + +class TestBootstrapModule: + """Tests for the bootstrap module initialization.""" + + def test_database_objects_is_list(self): + """DATABASE_OBJECTS should be a list.""" + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + # Reload to trigger module-level code with mocks + importlib.reload(bootstrap) + + assert isinstance(bootstrap.DATABASE_OBJECTS, list) + + def test_model_objects_is_list(self): + """MODEL_OBJECTS should be a list.""" + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert isinstance(bootstrap.MODEL_OBJECTS, list) + + def test_oci_objects_is_list(self): + """OCI_OBJECTS should be a list.""" + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert isinstance(bootstrap.OCI_OBJECTS, list) + + def test_settings_objects_is_list(self): + """SETTINGS_OBJECTS should be a list.""" + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert isinstance(bootstrap.SETTINGS_OBJECTS, list) + + def test_calls_all_bootstrap_functions(self): + """Bootstrap module should call all main() functions.""" + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + mock_databases.assert_called_once() + mock_models.assert_called_once() + mock_oci.assert_called_once() + mock_settings.assert_called_once() + + def test_stores_database_results(self, make_database): + """Bootstrap module should store database.main() results.""" + db1 = make_database(name="DB1") + db2 = make_database(name="DB2") + + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [db1, db2] + with 
patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert len(bootstrap.DATABASE_OBJECTS) == 2 + assert bootstrap.DATABASE_OBJECTS[0].name == "DB1" + + def test_stores_model_results(self, make_model): + """Bootstrap module should store models.main() results.""" + model1 = make_model(model_id="model1") + model2 = make_model(model_id="model2") + + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [model1, model2] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert len(bootstrap.MODEL_OBJECTS) == 2 + + def test_stores_oci_results(self, make_oci_config): + """Bootstrap module should store oci.main() results.""" + oci1 = make_oci_config(auth_profile="PROFILE1") + oci2 = make_oci_config(auth_profile="PROFILE2") + + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [oci1, oci2] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [] + + importlib.reload(bootstrap) + + assert len(bootstrap.OCI_OBJECTS) == 2 + + def test_stores_settings_results(self, make_settings): + """Bootstrap module should store settings.main() results.""" + settings1 = make_settings(client="client1") + settings2 = make_settings(client="client2") + + with patch("server.bootstrap.databases.main") as mock_databases: + mock_databases.return_value = [] + with patch("server.bootstrap.models.main") as mock_models: + mock_models.return_value = [] + with patch("server.bootstrap.oci.main") as mock_oci: + mock_oci.return_value = [] + with patch("server.bootstrap.settings.main") as mock_settings: + mock_settings.return_value = [settings1, settings2] + + importlib.reload(bootstrap) + + assert len(bootstrap.SETTINGS_OBJECTS) == 2 + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured in bootstrap module.""" + assert hasattr(bootstrap, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert bootstrap.logger.name == "bootstrap" diff --git a/test/unit/server/bootstrap/test_bootstrap_configfile.py b/test/unit/server/bootstrap/test_bootstrap_configfile.py new file mode 100644 index 00000000..12c505ca --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_configfile.py @@ -0,0 +1,229 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/configfile.py +Tests for ConfigStore class and config_file_path function. 
+""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import json +import os +import tempfile +from pathlib import Path +from threading import Thread, Barrier + +import pytest + +from server.bootstrap import configfile +from server.bootstrap.configfile import config_file_path + + +class TestConfigStore: + """Tests for the ConfigStore class.""" + + def test_load_from_file_success(self, reset_config_store, temp_config_file, make_settings): + """ConfigStore should load configuration from a valid JSON file.""" + settings = make_settings(client="test_client") + config_path = temp_config_file(client_settings=settings) + + try: + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert config.client_settings.client == "test_client" + finally: + os.unlink(config_path) + + def test_load_from_file_nonexistent_file(self, reset_config_store): + """ConfigStore should handle nonexistent files gracefully.""" + nonexistent_path = Path("/nonexistent/path/config.json") + + reset_config_store.load_from_file(nonexistent_path) + config = reset_config_store.get() + + assert config is None + + def test_load_from_file_wrong_extension_warns(self, reset_config_store, caplog): + """ConfigStore should warn when file has wrong extension.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as temp_file: + # Need valid client_settings with required 'client' field + json.dump( + { + "client_settings": {"client": "test"}, + "database_configs": [], + "model_configs": [], + "oci_configs": [], + "prompt_configs": [], + }, + temp_file, + ) + temp_path = Path(temp_file.name) + + try: + reset_config_store.load_from_file(temp_path) + assert "should be a .json file" in caplog.text + finally: + os.unlink(temp_path) + + def test_load_from_file_only_loads_once(self, reset_config_store, temp_config_file, make_settings): + """ConfigStore should only load configuration once (singleton pattern).""" + settings1 = make_settings(client="first_client") + settings2 = make_settings(client="second_client") + + config_path1 = temp_config_file(client_settings=settings1) + config_path2 = temp_config_file(client_settings=settings2) + + try: + reset_config_store.load_from_file(config_path1) + reset_config_store.load_from_file(config_path2) # Should be ignored + + config = reset_config_store.get() + assert config.client_settings.client == "first_client" + finally: + os.unlink(config_path1) + os.unlink(config_path2) + + def test_load_from_file_thread_safety(self, reset_config_store, temp_config_file, make_settings): + """ConfigStore should handle concurrent loading safely.""" + settings = make_settings(client="thread_test") + config_path = temp_config_file(client_settings=settings) + + num_threads = 5 + barrier = Barrier(num_threads) + results = [] + + def load_config(): + barrier.wait() # Synchronize threads + reset_config_store.load_from_file(config_path) + results.append(reset_config_store.get()) + + try: + threads = [Thread(target=load_config) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should see the same config + assert len(results) == num_threads + assert all(r is not None for r in results) + assert all(r.client_settings.client == "thread_test" for r in results) + finally: + os.unlink(config_path) + + def test_load_from_file_with_database_configs( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + 
"""ConfigStore should load database configurations.""" + settings = make_settings() + db = make_database(name="TEST_DB", user="admin") + config_path = temp_config_file(client_settings=settings, database_configs=[db]) + + try: + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert len(config.database_configs) == 1 + assert config.database_configs[0].name == "TEST_DB" + assert config.database_configs[0].user == "admin" + finally: + os.unlink(config_path) + + def test_load_from_file_with_model_configs(self, reset_config_store, temp_config_file, make_settings, make_model): + """ConfigStore should load model configurations.""" + settings = make_settings() + model = make_model(model_id="test-model", provider="openai") + config_path = temp_config_file(client_settings=settings, model_configs=[model]) + + try: + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert len(config.model_configs) == 1 + assert config.model_configs[0].id == "test-model" + finally: + os.unlink(config_path) + + def test_load_from_file_with_oci_configs( + self, reset_config_store, temp_config_file, make_settings, make_oci_config + ): + """ConfigStore should load OCI configurations.""" + settings = make_settings() + oci_config = make_oci_config(auth_profile="TEST_PROFILE") + config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) + + try: + reset_config_store.load_from_file(config_path) + config = reset_config_store.get() + + assert config is not None + assert len(config.oci_configs) == 1 + assert config.oci_configs[0].auth_profile == "TEST_PROFILE" + finally: + os.unlink(config_path) + + def test_load_from_file_invalid_json(self, reset_config_store): + """ConfigStore should raise error for invalid JSON.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as temp_file: + temp_file.write("not valid json {") + temp_path = Path(temp_file.name) + + try: + with pytest.raises(json.JSONDecodeError): + reset_config_store.load_from_file(temp_path) + finally: + os.unlink(temp_path) + + def test_get_returns_none_when_not_loaded(self, reset_config_store): + """ConfigStore.get() should return None when config not loaded.""" + config = reset_config_store.get() + assert config is None + + +class TestConfigFilePath: + """Tests for the config_file_path function.""" + + def test_config_file_path_returns_string(self): + """config_file_path should return a string path.""" + path = config_file_path() + assert isinstance(path, str) + + def test_config_file_path_ends_with_json(self): + """config_file_path should return a .json file path.""" + path = config_file_path() + assert path.endswith(".json") + + def test_config_file_path_contains_etc_directory(self): + """config_file_path should include etc directory.""" + path = config_file_path() + assert "etc" in path + assert "configuration.json" in path + + def test_config_file_path_is_absolute(self): + """config_file_path should return an absolute path.""" + path = config_file_path() + assert os.path.isabs(path) + + def test_config_file_path_parent_is_server_directory(self): + """config_file_path should be relative to server directory.""" + path = config_file_path() + path_obj = Path(path) + # Should be under server/etc/configuration.json + assert path_obj.parent.name == "etc" + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger 
should be configured in configfile module.""" + assert hasattr(configfile, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert configfile.logger.name == "bootstrap.configfile" diff --git a/test/unit/server/bootstrap/test_bootstrap_databases.py b/test/unit/server/bootstrap/test_bootstrap_databases.py new file mode 100644 index 00000000..5ab5afb6 --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_databases.py @@ -0,0 +1,227 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/databases.py +Tests for database bootstrap functionality. +""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import os + +import pytest + +from server.bootstrap import databases as databases_module +from common.schema import Database + + +@pytest.mark.usefixtures("reset_config_store", "clean_env") +class TestDatabasesMain: + """Tests for the databases.main() function.""" + + def test_main_returns_list_of_databases(self): + """main() should return a list of Database objects.""" + result = databases_module.main() + + assert isinstance(result, list) + assert all(isinstance(db, Database) for db in result) + + def test_main_creates_default_database_when_no_config(self): + """main() should create DEFAULT database when no config is loaded.""" + result = databases_module.main() + + db_names = [db.name for db in result] + assert "DEFAULT" in db_names + + def test_main_uses_env_vars_for_default_database(self): + """main() should use environment variables for DEFAULT database.""" + os.environ["DB_USERNAME"] = "env_user" + os.environ["DB_PASSWORD"] = "env_password" + os.environ["DB_DSN"] = "env_dsn:1521/ENVPDB" + os.environ["TNS_ADMIN"] = "/env/tns_admin" + + try: + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "env_user" + assert default_db.password == "env_password" + assert default_db.dsn == "env_dsn:1521/ENVPDB" + assert default_db.config_dir == "/env/tns_admin" + finally: + del os.environ["DB_USERNAME"] + del os.environ["DB_PASSWORD"] + del os.environ["DB_DSN"] + del os.environ["TNS_ADMIN"] + + def test_main_sets_wallet_location_when_wallet_password_present(self): + """main() should set wallet_location when wallet_password is provided.""" + os.environ["DB_WALLET_PASSWORD"] = "wallet_pass" + os.environ["TNS_ADMIN"] = "/wallet/path" + + try: + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.wallet_password == "wallet_pass" + assert default_db.wallet_location == "/wallet/path" + finally: + del os.environ["DB_WALLET_PASSWORD"] + del os.environ["TNS_ADMIN"] + + def test_main_with_config_file_databases( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should load databases from config file.""" + settings = make_settings() + db1 = make_database(name="CONFIG_DB1", user="config_user1") + db2 = make_database(name="CONFIG_DB2", user="config_user2") + config_path = temp_config_file(client_settings=settings, database_configs=[db1, db2]) + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + db_names = [db.name for db in result] + assert "CONFIG_DB1" in db_names + assert "CONFIG_DB2" in db_names + finally: + os.unlink(config_path) + + def 
test_main_overrides_default_from_config_with_env_vars( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should override DEFAULT database from config with env vars.""" + settings = make_settings() + db = make_database(name="DEFAULT", user="config_user", password="config_pass", dsn="config_dsn") + config_path = temp_config_file(client_settings=settings, database_configs=[db]) + + os.environ["DB_USERNAME"] = "env_user" + os.environ["DB_PASSWORD"] = "env_password" + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "env_user" + assert default_db.password == "env_password" + # DSN not in env, should keep config value + assert default_db.dsn == "config_dsn" + finally: + os.unlink(config_path) + del os.environ["DB_USERNAME"] + del os.environ["DB_PASSWORD"] + + def test_main_raises_on_duplicate_database_names( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should raise ValueError for duplicate database names.""" + settings = make_settings() + db1 = make_database(name="DUP_DB", user="user1") + db2 = make_database(name="dup_db", user="user2") # Case-insensitive duplicate + config_path = temp_config_file(client_settings=settings, database_configs=[db1, db2]) + + try: + reset_config_store.load_from_file(config_path) + + with pytest.raises(ValueError, match="Duplicate database name"): + databases_module.main() + finally: + os.unlink(config_path) + + def test_main_creates_default_when_not_in_config( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should create DEFAULT database from env when not in config.""" + settings = make_settings() + db = make_database(name="OTHER_DB", user="other_user") + config_path = temp_config_file(client_settings=settings, database_configs=[db]) + + os.environ["DB_USERNAME"] = "default_env_user" + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + db_names = [db.name for db in result] + assert "DEFAULT" in db_names + assert "OTHER_DB" in db_names + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.user == "default_env_user" + finally: + os.unlink(config_path) + del os.environ["DB_USERNAME"] + + def test_main_handles_case_insensitive_default_name( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should handle DEFAULT name case-insensitively.""" + settings = make_settings() + db = make_database(name="default", user="config_user") # lowercase + config_path = temp_config_file(client_settings=settings, database_configs=[db]) + + os.environ["DB_USERNAME"] = "env_user" + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + # Should find and update the lowercase "default" + default_db = next(db for db in result if db.name.upper() == "DEFAULT") + assert default_db.user == "env_user" + finally: + os.unlink(config_path) + del os.environ["DB_USERNAME"] + + def test_main_preserves_non_default_databases_unchanged( + self, reset_config_store, temp_config_file, make_settings, make_database + ): + """main() should not modify non-DEFAULT databases.""" + settings = make_settings() + db = make_database(name="CUSTOM_DB", user="custom_user", password="custom_pass") + config_path = temp_config_file(client_settings=settings, database_configs=[db]) + + 
os.environ["DB_USERNAME"] = "should_not_apply" + + try: + reset_config_store.load_from_file(config_path) + result = databases_module.main() + + custom_db = next(db for db in result if db.name == "CUSTOM_DB") + assert custom_db.user == "custom_user" + assert custom_db.password == "custom_pass" + finally: + os.unlink(config_path) + del os.environ["DB_USERNAME"] + + def test_main_default_config_dir_fallback(self): + """main() should use 'tns_admin' as default config_dir when not specified.""" + result = databases_module.main() + + default_db = next(db for db in result if db.name == "DEFAULT") + assert default_db.config_dir == "tns_admin" + + +@pytest.mark.usefixtures("reset_config_store", "clean_env") +class TestDatabasesMainAsScript: + """Tests for running databases module as script.""" + + def test_main_callable_directly(self): + """main() should be callable when running as script.""" + result = databases_module.main() + assert result is not None + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured in databases module.""" + assert hasattr(databases_module, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert databases_module.logger.name == "bootstrap.databases" diff --git a/test/unit/server/bootstrap/test_bootstrap_models.py b/test/unit/server/bootstrap/test_bootstrap_models.py new file mode 100644 index 00000000..27728c5d --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_models.py @@ -0,0 +1,418 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/models.py +Tests for model bootstrap functionality. 
+""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import os +from unittest.mock import patch + +import pytest + +from server.bootstrap import models as models_module +from common.schema import Model + + +@pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible") +class TestModelsMain: + """Tests for the models.main() function.""" + + def test_main_returns_list_of_models(self): + """main() should return a list of Model objects.""" + result = models_module.main() + + assert isinstance(result, list) + assert all(isinstance(m, Model) for m in result) + + def test_main_includes_base_models(self): + """main() should include base model configurations.""" + result = models_module.main() + + model_ids = [m.id for m in result] + # Should include at least some base models + assert "gpt-4o-mini" in model_ids + assert "command-r" in model_ids + + def test_main_enables_models_with_api_keys(self): + """main() should enable models when API keys are present.""" + os.environ["OPENAI_API_KEY"] = "test-openai-key" + + try: + result = models_module.main() + + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.enabled is True + assert openai_model.api_key == "test-openai-key" + finally: + del os.environ["OPENAI_API_KEY"] + + def test_main_disables_models_without_api_keys(self): + """main() should disable models when API keys are not present.""" + result = models_module.main() + + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.enabled is False + + @pytest.mark.usefixtures("reset_config_store", "clean_env") + def test_main_checks_url_accessibility(self): + """main() should check URL accessibility for enabled models.""" + os.environ["OPENAI_API_KEY"] = "test-key" + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (False, "Connection refused") + + try: + result = models_module.main() + + # Model should be disabled if URL is not accessible + openai_model = next(m for m in result if m.id == "gpt-4o-mini") + assert openai_model.enabled is False + mock_accessible.assert_called() + finally: + del os.environ["OPENAI_API_KEY"] + + @pytest.mark.usefixtures("reset_config_store", "clean_env") + def test_main_caches_url_accessibility_results(self): + """main() should cache URL accessibility results for same URLs.""" + os.environ["OPENAI_API_KEY"] = "test-key" + os.environ["COHERE_API_KEY"] = "test-key" + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + + try: + models_module.main() + + # Multiple models share the same base URL, should only check once per URL + call_urls = [call[0][0] for call in mock_accessible.call_args_list] + # Should not have duplicate URL checks + assert len(call_urls) == len(set(call_urls)) + finally: + del os.environ["OPENAI_API_KEY"] + del os.environ["COHERE_API_KEY"] + + +@pytest.mark.usefixtures("clean_env") +class TestGetBaseModelsList: + """Tests for the _get_base_models_list function.""" + + def test_returns_list_of_dicts(self): + """_get_base_models_list should return a list of dictionaries.""" + result = models_module._get_base_models_list() + + assert isinstance(result, list) + assert all(isinstance(m, dict) for m in result) + + def test_includes_required_fields(self): + """_get_base_models_list should include required fields for each model.""" + result = models_module._get_base_models_list() + + for model in result: + 
assert "id" in model + assert "type" in model + assert "provider" in model + assert "api_base" in model + + def test_includes_ll_and_embed_models(self): + """_get_base_models_list should include both LLM and embedding models.""" + result = models_module._get_base_models_list() + + types = {m["type"] for m in result} + assert "ll" in types + assert "embed" in types + + +class TestCheckForDuplicates: + """Tests for the _check_for_duplicates function.""" + + def test_no_error_for_unique_models(self): + """_check_for_duplicates should not raise for unique models.""" + models_list = [ + {"id": "model1", "provider": "openai"}, + {"id": "model2", "provider": "openai"}, + {"id": "model1", "provider": "cohere"}, # Same ID, different provider + ] + + # Should not raise + models_module._check_for_duplicates(models_list) + + def test_raises_for_duplicate_models(self): + """_check_for_duplicates should raise ValueError for duplicates.""" + models_list = [ + {"id": "model1", "provider": "openai"}, + {"id": "model1", "provider": "openai"}, # Duplicate + ] + + with pytest.raises(ValueError, match="already exists"): + models_module._check_for_duplicates(models_list) + + +class TestValuesDiffer: + """Tests for the _values_differ function.""" + + def test_bool_comparison(self): + """_values_differ should handle boolean comparisons.""" + assert models_module._values_differ(True, False) is True + assert models_module._values_differ(True, True) is False + assert models_module._values_differ(False, False) is False + + def test_numeric_comparison(self): + """_values_differ should handle numeric comparisons.""" + assert models_module._values_differ(1, 2) is True + assert models_module._values_differ(1.0, 1.0) is False + assert models_module._values_differ(1, 1.0) is False + # Small float differences should be considered equal + assert models_module._values_differ(1.0, 1.0 + 1e-9) is False + assert models_module._values_differ(1.0, 1.1) is True + + def test_string_comparison(self): + """_values_differ should handle string comparisons with strip.""" + assert models_module._values_differ("test", "test") is False + assert models_module._values_differ(" test ", "test") is False + assert models_module._values_differ("test", "other") is True + + def test_general_comparison(self): + """_values_differ should handle general equality comparison.""" + assert models_module._values_differ([1, 2], [1, 2]) is False + assert models_module._values_differ([1, 2], [1, 3]) is True + assert models_module._values_differ(None, None) is False + assert models_module._values_differ(None, "value") is True + + +@pytest.mark.usefixtures("reset_config_store") +class TestMergeWithConfigStore: + """Tests for the _merge_with_config_store function.""" + + def test_returns_unchanged_when_no_config(self): + """_merge_with_config_store should return unchanged list when no config.""" + models_list = [{"id": "model1", "provider": "openai", "enabled": False}] + + result = models_module._merge_with_config_store(models_list) + + assert result == models_list + + def test_merges_config_store_models( + self, reset_config_store, temp_config_file, make_settings, make_model + ): + """_merge_with_config_store should merge models from ConfigStore.""" + settings = make_settings() + config_model = make_model(model_id="config-model", provider="custom") + config_path = temp_config_file(client_settings=settings, model_configs=[config_model]) + + models_list = [{"id": "existing", "provider": "openai", "enabled": False}] + + try: + 
reset_config_store.load_from_file(config_path) + result = models_module._merge_with_config_store(models_list) + + model_keys = [(m["provider"], m["id"]) for m in result] + assert ("custom", "config-model") in model_keys + assert ("openai", "existing") in model_keys + finally: + os.unlink(config_path) + + def test_overrides_existing_model_values( + self, reset_config_store, temp_config_file, make_settings, make_model + ): + """_merge_with_config_store should override existing model values.""" + settings = make_settings() + config_model = make_model(model_id="existing", provider="openai", enabled=True) + config_path = temp_config_file(client_settings=settings, model_configs=[config_model]) + + models_list = [ + {"id": "existing", "provider": "openai", "enabled": False, "api_base": "https://api.openai.com/v1"} + ] + + try: + reset_config_store.load_from_file(config_path) + result = models_module._merge_with_config_store(models_list) + + merged_model = next(m for m in result if m["id"] == "existing") + assert merged_model["enabled"] is True + finally: + os.unlink(config_path) + + +class ModelDict(dict): + """Dict subclass that also supports attribute access for 'id'. + + The _update_env_var function in models.py uses both dict-style (.get(), []) + and attribute-style (.id) access, so tests need objects that support both. + """ + + def __getattr__(self, name): + if name in self: + return self[name] + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + +@pytest.mark.usefixtures("clean_env") +class TestApplyEnvVarOverrides: + """Tests for the _apply_env_var_overrides function.""" + + def test_applies_cohere_api_key(self): + """_apply_env_var_overrides should apply COHERE_API_KEY.""" + # Use ModelDict to support both dict and attribute access (needed for model.id) + models_list = [ModelDict({"id": "command-r", "provider": "cohere", "api_key": "original"})] + os.environ["COHERE_API_KEY"] = "env-key" + + try: + models_module._apply_env_var_overrides(models_list) + + assert models_list[0]["api_key"] == "env-key" + finally: + del os.environ["COHERE_API_KEY"] + + def test_applies_ollama_url(self): + """_apply_env_var_overrides should apply ON_PREM_OLLAMA_URL.""" + models_list = [ModelDict({"id": "llama3.1", "provider": "ollama", "api_base": "http://localhost:11434"})] + os.environ["ON_PREM_OLLAMA_URL"] = "http://custom:11434" + + try: + models_module._apply_env_var_overrides(models_list) + + assert models_list[0]["api_base"] == "http://custom:11434" + finally: + del os.environ["ON_PREM_OLLAMA_URL"] + + def test_does_not_apply_to_wrong_provider(self): + """_apply_env_var_overrides should not apply overrides to wrong provider.""" + models_list = [ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "original"})] + os.environ["COHERE_API_KEY"] = "env-key" + + try: + models_module._apply_env_var_overrides(models_list) + + assert models_list[0]["api_key"] == "original" + finally: + del os.environ["COHERE_API_KEY"] + + +@pytest.mark.usefixtures("clean_env") +class TestUpdateEnvVar: + """Tests for the _update_env_var function. + + Note: _update_env_var uses dict-style access (.get(), []) but also accesses + model.id directly for logging. Use ModelDict for compatibility. 
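+
+    A minimal illustration of the dual access ModelDict provides (values
+    are arbitrary):
+
+        model = ModelDict({"id": "gpt-4o-mini", "provider": "openai"})
+        model.get("provider")  # dict-style access -> "openai"
+        model.id               # attribute access  -> "gpt-4o-mini"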
+ """ + + def test_updates_matching_provider(self): + """_update_env_var should update model when provider matches.""" + model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "old"}) + os.environ["TEST_KEY"] = "new" + + try: + models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") + + assert model["api_key"] == "new" + finally: + del os.environ["TEST_KEY"] + + def test_ignores_non_matching_provider(self): + """_update_env_var should not update when provider doesn't match.""" + model = ModelDict({"id": "command-r", "provider": "cohere", "api_key": "old"}) + os.environ["TEST_KEY"] = "new" + + try: + models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") + + assert model["api_key"] == "old" + finally: + del os.environ["TEST_KEY"] + + def test_ignores_when_env_var_not_set(self): + """_update_env_var should not update when env var is not set.""" + model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "old"}) + + models_module._update_env_var(model, "openai", "api_key", "NONEXISTENT_VAR") + + assert model["api_key"] == "old" + + def test_ignores_when_value_unchanged(self): + """_update_env_var should not update when value is the same.""" + model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "same"}) + os.environ["TEST_KEY"] = "same" + + try: + models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") + + assert model["api_key"] == "same" + finally: + del os.environ["TEST_KEY"] + + +@pytest.mark.usefixtures("clean_env") +class TestCheckUrlAccessibility: + """Tests for the _check_url_accessibility function.""" + + def test_disables_inaccessible_urls(self): + """_check_url_accessibility should disable models with inaccessible URLs.""" + models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": True}] + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (False, "Connection refused") + + models_module._check_url_accessibility(models_list) + + assert models_list[0]["enabled"] is False + + def test_keeps_accessible_urls_enabled(self): + """_check_url_accessibility should keep models with accessible URLs enabled.""" + models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": True}] + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + + models_module._check_url_accessibility(models_list) + + assert models_list[0]["enabled"] is True + + def test_skips_disabled_models(self): + """_check_url_accessibility should skip models that are already disabled.""" + models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": False}] + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + models_module._check_url_accessibility(models_list) + + mock_accessible.assert_not_called() + + def test_caches_url_results(self): + """_check_url_accessibility should cache results for the same URL.""" + models_list = [ + {"id": "test1", "api_base": "http://localhost:1234", "enabled": True}, + {"id": "test2", "api_base": "http://localhost:1234", "enabled": True}, + ] + + with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: + mock_accessible.return_value = (True, "OK") + + models_module._check_url_accessibility(models_list) + + # Should only be called once for the shared URL + assert mock_accessible.call_count == 1 + + +@pytest.mark.usefixtures("reset_config_store", "clean_env", 
"mock_is_url_accessible") +class TestModelsMainAsScript: + """Tests for running models module as script.""" + + def test_main_callable_directly(self): + """main() should be callable when running as script.""" + result = models_module.main() + assert result is not None + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured in models module.""" + assert hasattr(models_module, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert models_module.logger.name == "bootstrap.models" diff --git a/test/unit/server/bootstrap/test_bootstrap_oci.py b/test/unit/server/bootstrap/test_bootstrap_oci.py new file mode 100644 index 00000000..89242c8e --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_oci.py @@ -0,0 +1,329 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/oci.py +Tests for OCI bootstrap functionality. +""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import os +from unittest.mock import patch, MagicMock + +import pytest +import oci + +from server.bootstrap import oci as oci_module +from common.schema import OracleCloudSettings + + +@pytest.mark.usefixtures("reset_config_store", "clean_env") +class TestOciMain: + """Tests for the oci.main() function.""" + + def test_main_returns_list_of_oci_settings(self): + """main() should return a list of OracleCloudSettings objects.""" + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + assert isinstance(result, list) + assert all(isinstance(s, OracleCloudSettings) for s in result) + + def test_main_creates_default_profile_when_no_config(self): + """main() should create DEFAULT profile when no OCI config exists.""" + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + profile_names = [s.auth_profile for s in result] + assert oci.config.DEFAULT_PROFILE in profile_names + + def test_main_reads_oci_config_file(self): + """main() should read from OCI config file when it exists.""" + # User OCID must match pattern ^([0-9a-zA-Z-_]+[.:])([0-9a-zA-Z-_]*[.:]){3,}([0-9a-zA-Z-_]+)$ + mock_config_data = { + "tenancy": "ocid1.tenancy.oc1..test123", + "region": "us-phoenix-1", + "user": "ocid1.user.oc1..test123", # Valid OCID pattern + "fingerprint": "test-fingerprint", + "key_file": "/path/to/key.pem", + } + + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = [] + mock_parser.return_value = mock_instance + + with patch("oci.config.from_file", return_value=mock_config_data.copy()): + result = oci_module.main() + + assert len(result) >= 1 + default_profile = next((p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE), None) + assert default_profile is not None + + def test_main_applies_env_var_overrides_to_default(self): + """main() should apply environment variable overrides to DEFAULT profile.""" + # User OCID must match pattern ^([0-9a-zA-Z-_]+[.:])([0-9a-zA-Z-_]*[.:]){3,}([0-9a-zA-Z-_]+)$ + os.environ["OCI_CLI_TENANCY"] = "env-tenancy" + os.environ["OCI_CLI_REGION"] = "us-chicago-1" + os.environ["OCI_CLI_USER"] = "ocid1.user.oc1..envuser123" # Valid OCID pattern + + try: + with patch("oci.config.from_file", 
side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.tenancy == "env-tenancy" + assert default_profile.region == "us-chicago-1" + assert default_profile.user == "ocid1.user.oc1..envuser123" + finally: + del os.environ["OCI_CLI_TENANCY"] + del os.environ["OCI_CLI_REGION"] + del os.environ["OCI_CLI_USER"] + + def test_main_env_overrides_genai_settings(self): + """main() should apply GenAI environment variable overrides.""" + # genai_compartment_id must match OCID pattern + os.environ["OCI_GENAI_COMPARTMENT_ID"] = "ocid1.compartment.oc1..genaitest" + os.environ["OCI_GENAI_REGION"] = "us-chicago-1" + + try: + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.genai_compartment_id == "ocid1.compartment.oc1..genaitest" + assert default_profile.genai_region == "us-chicago-1" + finally: + del os.environ["OCI_GENAI_COMPARTMENT_ID"] + del os.environ["OCI_GENAI_REGION"] + + def test_main_security_token_authentication(self): + """main() should set authentication based on security_token_file in profile. + + Note: Due to how profile.update() works, the authentication logic reads the + OLD value of security_token_file before the update completes. If security_token_file + is already set in the profile, authentication becomes 'security_token'. + For env var alone without existing profile value, use OCI_CLI_AUTH instead. + """ + # To get security_token auth, we need OCI_CLI_AUTH explicitly set + # OR we need security_token_file already in the profile before overrides + os.environ["OCI_CLI_SECURITY_TOKEN_FILE"] = "/path/to/token" + os.environ["OCI_CLI_AUTH"] = "security_token" # Must explicitly set + + try: + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.authentication == "security_token" + assert default_profile.security_token_file == "/path/to/token" + finally: + del os.environ["OCI_CLI_SECURITY_TOKEN_FILE"] + del os.environ["OCI_CLI_AUTH"] + + def test_main_explicit_auth_env_var(self): + """main() should use OCI_CLI_AUTH env var when specified.""" + os.environ["OCI_CLI_AUTH"] = "instance_principal" + + try: + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + assert default_profile.authentication == "instance_principal" + finally: + del os.environ["OCI_CLI_AUTH"] + + def test_main_loads_multiple_profiles(self): + """main() should load multiple profiles from OCI config.""" + profiles = ["PROFILE1", "PROFILE2"] + + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = profiles + mock_parser.return_value = mock_instance + + def mock_from_file(**kwargs): + profile_name = kwargs.get("profile_name") + # User must be None or valid OCID pattern + return { + "tenancy": f"tenancy-{profile_name}", + "region": "us-ashburn-1", + "fingerprint": "fingerprint", + "key_file": "/path/to/key.pem", + } + + with patch("oci.config.from_file", side_effect=mock_from_file): + result = 
oci_module.main() + + profile_names = [p.auth_profile for p in result] + assert "PROFILE1" in profile_names + assert "PROFILE2" in profile_names + + def test_main_handles_invalid_key_file_path(self): + """main() should skip profiles with invalid key file paths.""" + profiles = ["VALID", "INVALID"] + + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = profiles + mock_parser.return_value = mock_instance + + def mock_from_file(**kwargs): + profile_name = kwargs.get("profile_name") + if profile_name == "INVALID": + raise oci.exceptions.InvalidKeyFilePath("Invalid key file") + # User must be None or valid OCID pattern + return { + "tenancy": "tenancy", + "region": "us-ashburn-1", + "fingerprint": "fingerprint", + "key_file": "/path/to/key.pem", + } + + with patch("oci.config.from_file", side_effect=mock_from_file): + result = oci_module.main() + + profile_names = [p.auth_profile for p in result] + assert "VALID" in profile_names + # INVALID should be skipped, DEFAULT should be created + + def test_main_merges_config_store_oci_configs( + self, reset_config_store, temp_config_file, make_settings, make_oci_config + ): + """main() should merge OCI configs from ConfigStore.""" + settings = make_settings() + oci_config = make_oci_config(auth_profile="CONFIG_PROFILE", tenancy="config-tenancy") + config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) + + try: + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + reset_config_store.load_from_file(config_path) + result = oci_module.main() + + profile_names = [p.auth_profile for p in result] + assert "CONFIG_PROFILE" in profile_names + + config_profile = next(p for p in result if p.auth_profile == "CONFIG_PROFILE") + assert config_profile.tenancy == "config-tenancy" + finally: + os.unlink(config_path) + + def test_main_config_store_overrides_existing_profile( + self, reset_config_store, temp_config_file, make_settings, make_oci_config + ): + """main() should override existing profiles with ConfigStore configs.""" + settings = make_settings() + oci_config = make_oci_config(auth_profile=oci.config.DEFAULT_PROFILE, tenancy="override-tenancy") + config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) + + # User must be None or valid OCID pattern + mock_file_config = { + "tenancy": "file-tenancy", + "region": "us-ashburn-1", + "fingerprint": "fingerprint", + "key_file": "/path/to/key.pem", + } + + try: + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = [] + mock_parser.return_value = mock_instance + + with patch("oci.config.from_file", return_value=mock_file_config.copy()): + reset_config_store.load_from_file(config_path) + result = oci_module.main() + + default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) + # ConfigStore should override file config + assert default_profile.tenancy == "override-tenancy" + finally: + os.unlink(config_path) + + def test_main_uses_custom_config_file_path(self): + """main() should use OCI_CLI_CONFIG_FILE env var for config path.""" + custom_path = "/custom/oci/config" + os.environ["OCI_CLI_CONFIG_FILE"] = custom_path + + try: + with patch("configparser.ConfigParser") as mock_parser: + mock_instance = MagicMock() + mock_instance.sections.return_value = [] + mock_parser.return_value = mock_instance + + with patch("oci.config.from_file", 
side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + + # The expanded path should be used + assert len(result) >= 1 + finally: + del os.environ["OCI_CLI_CONFIG_FILE"] + + +@pytest.mark.usefixtures("clean_env") +class TestApplyEnvOverrides: + """Tests for the _apply_env_overrides_to_default_profile function.""" + + def test_override_function_modifies_default_profile(self): + """_apply_env_overrides_to_default_profile should modify DEFAULT profile.""" + config = [{"auth_profile": oci.config.DEFAULT_PROFILE, "tenancy": "original"}] + + os.environ["OCI_CLI_TENANCY"] = "overridden" + + try: + oci_module._apply_env_overrides_to_default_profile(config) + + assert config[0]["tenancy"] == "overridden" + finally: + del os.environ["OCI_CLI_TENANCY"] + + def test_override_function_ignores_non_default_profiles(self): + """_apply_env_overrides_to_default_profile should not modify non-DEFAULT profiles.""" + config = [{"auth_profile": "CUSTOM", "tenancy": "original"}] + + os.environ["OCI_CLI_TENANCY"] = "overridden" + + try: + oci_module._apply_env_overrides_to_default_profile(config) + + assert config[0]["tenancy"] == "original" + finally: + del os.environ["OCI_CLI_TENANCY"] + + def test_override_logs_changes(self, caplog): + """_apply_env_overrides_to_default_profile should log overrides.""" + config = [{"auth_profile": oci.config.DEFAULT_PROFILE, "tenancy": "original"}] + + os.environ["OCI_CLI_TENANCY"] = "new-tenancy" + + try: + oci_module._apply_env_overrides_to_default_profile(config) + + assert "Environment variable overrides" in caplog.text or "new-tenancy" in str(config) + finally: + del os.environ["OCI_CLI_TENANCY"] + + +@pytest.mark.usefixtures("reset_config_store", "clean_env") +class TestOciMainAsScript: + """Tests for running OCI module as script.""" + + def test_main_callable_directly(self): + """main() should be callable when running as script.""" + with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): + result = oci_module.main() + assert result is not None + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured in oci module.""" + assert hasattr(oci_module, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert oci_module.logger.name == "bootstrap.oci" diff --git a/test/unit/server/bootstrap/test_bootstrap_settings.py b/test/unit/server/bootstrap/test_bootstrap_settings.py new file mode 100644 index 00000000..5bac59b8 --- /dev/null +++ b/test/unit/server/bootstrap/test_bootstrap_settings.py @@ -0,0 +1,143 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for server/bootstrap/settings.py +Tests for settings bootstrap functionality. 
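+
+settings.main() is expected to return exactly two Settings objects, one per
+base client name ("default" and "server"), seeded from the ConfigStore
+configuration when one has been loaded (see the assertions below).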
+""" + +# pylint: disable=redefined-outer-name protected-access too-few-public-methods + +import os +from unittest.mock import patch, MagicMock + +import pytest + +from server.bootstrap import settings as settings_module +from common.schema import Settings + + +@pytest.mark.usefixtures("reset_config_store") +class TestSettingsMain: + """Tests for the settings.main() function.""" + + def test_main_returns_list_of_settings(self): + """main() should return a list of Settings objects.""" + result = settings_module.main() + + assert isinstance(result, list) + assert all(isinstance(s, Settings) for s in result) + + def test_main_creates_default_and_server_clients(self): + """main() should create settings for 'default' and 'server' clients.""" + result = settings_module.main() + + client_names = [s.client for s in result] + assert "default" in client_names + assert "server" in client_names + assert len(result) == 2 + + def test_main_without_config_uses_default_settings(self): + """main() should use default Settings when no config is loaded.""" + result = settings_module.main() + + # Both should have default Settings values + for s in result: + assert isinstance(s, Settings) + assert s.client in ["default", "server"] + + def test_main_with_config_uses_config_settings(self, reset_config_store, temp_config_file, make_settings): + """main() should use config file settings when available.""" + # Create settings with custom values + custom_settings = make_settings(client="config_client") + custom_settings.ll_model.temperature = 0.9 + custom_settings.ll_model.max_tokens = 8192 + + config_path = temp_config_file(client_settings=custom_settings) + + try: + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + # Both clients should inherit from config settings + for s in result: + assert s.ll_model.temperature == 0.9 + assert s.ll_model.max_tokens == 8192 + # Client name should be overridden to default/server + assert s.client in ["default", "server"] + finally: + os.unlink(config_path) + + def test_main_preserves_client_names_from_base_list(self, reset_config_store, temp_config_file, make_settings): + """main() should override client field from config with base client names.""" + custom_settings = make_settings(client="original_name") + config_path = temp_config_file(client_settings=custom_settings) + + try: + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + # Client names should be "default" and "server", not "original_name" + client_names = [s.client for s in result] + assert "original_name" not in client_names + assert "default" in client_names + assert "server" in client_names + finally: + os.unlink(config_path) + + def test_main_with_config_but_no_client_settings(self, reset_config_store): + """main() should use default Settings when config has no client_settings.""" + mock_config = MagicMock() + mock_config.client_settings = None + + with patch.object(reset_config_store, "get", return_value=mock_config): + result = settings_module.main() + + assert len(result) == 2 + assert all(isinstance(s, Settings) for s in result) + + def test_main_creates_copies_with_different_clients(self, reset_config_store, temp_config_file, make_settings): + """main() should create separate Settings objects with unique client names. + + Note: Pydantic's model_copy() creates shallow copies by default, + so nested objects (like ll_model) may be shared. However, the top-level + Settings objects should be distinct with their own 'client' values. 
+ """ + custom_settings = make_settings(client="config_client") + config_path = temp_config_file(client_settings=custom_settings) + + try: + reset_config_store.load_from_file(config_path) + result = settings_module.main() + + # The Settings objects themselves should be distinct + assert result[0] is not result[1] + # And have different client names + assert result[0].client != result[1].client + assert result[0].client in ["default", "server"] + assert result[1].client in ["default", "server"] + finally: + os.unlink(config_path) + + +@pytest.mark.usefixtures("reset_config_store") +class TestSettingsMainAsScript: + """Tests for running settings module as script.""" + + def test_main_callable_directly(self): + """main() should be callable when running as script.""" + # This tests the if __name__ == "__main__" block indirectly + result = settings_module.main() + assert result is not None + + +class TestLoggerConfiguration: + """Tests for logger configuration.""" + + def test_logger_exists(self): + """Logger should be configured in settings module.""" + assert hasattr(settings_module, "logger") + + def test_logger_name(self): + """Logger should have correct name.""" + assert settings_module.logger.name == "bootstrap.settings" diff --git a/tests/server/unit/bootstrap/test_bootstrap.py b/tests/server/unit/bootstrap/test_bootstrap.py index 9caedd01..9f28e5ed 100644 --- a/tests/server/unit/bootstrap/test_bootstrap.py +++ b/tests/server/unit/bootstrap/test_bootstrap.py @@ -3,48 +3,20 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ # spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel -import importlib -from unittest.mock import patch, MagicMock - -from server.bootstrap import bootstrap - - -class TestBootstrap: - """Test bootstrap module functionality""" - - @patch("server.bootstrap.databases.main") - @patch("server.bootstrap.models.main") - @patch("server.bootstrap.oci.main") - @patch("server.bootstrap.settings.main") - def test_module_imports_and_initialization( - self, mock_settings, mock_oci, mock_models, mock_databases - ): - """Test that all bootstrap objects are properly initialized""" - # Mock return values - mock_databases.return_value = [MagicMock()] - mock_models.return_value = [MagicMock()] - mock_oci.return_value = [MagicMock()] - mock_settings.return_value = [MagicMock()] - - # Reload the module to trigger initialization - - importlib.reload(bootstrap) - - # Verify all bootstrap functions were called - mock_databases.assert_called_once() - mock_models.assert_called_once() - mock_oci.assert_called_once() - mock_settings.assert_called_once() - - # Verify objects are created - assert hasattr(bootstrap, "DATABASE_OBJECTS") - assert hasattr(bootstrap, "MODEL_OBJECTS") - assert hasattr(bootstrap, "OCI_OBJECTS") - assert hasattr(bootstrap, "SETTINGS_OBJECTS") - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(bootstrap, "logger") - assert bootstrap.logger.name == "bootstrap" +# ============================================================================= +# DEPRECATED: Tests in this file have been replaced by more comprehensive tests +# in test/unit/server/bootstrap/test_bootstrap_bootstrap.py +# ============================================================================= +# +# test_module_imports_and_initialization -> Replaced by: +# - 
test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_calls_all_bootstrap_functions +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_database_objects_is_list +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_model_objects_is_list +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_oci_objects_is_list +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_settings_objects_is_list +# +# test_logger_exists -> Replaced by: +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_exists +# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_name +# From 73b66098fafe891ed8e55767341283f9efeb2ddd Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 18:47:30 +0000 Subject: [PATCH 09/20] refresh tests --- .gitignore | 2 +- src/server/bootstrap/configfile.py | 6 + test/integration/server/api/conftest.py | 49 +-- test/integration/server/bootstrap/conftest.py | 83 +---- .../bootstrap/test_bootstrap_databases.py | 42 +-- .../server/bootstrap/test_bootstrap_models.py | 33 +- test/shared_fixtures.py | 330 ++++++++++++++++++ test/unit/server/api/conftest.py | 171 ++------- test/unit/server/bootstrap/conftest.py | 265 +------------- .../bootstrap/test_bootstrap_databases.py | 102 +++--- .../server/bootstrap/test_bootstrap_models.py | 29 +- 11 files changed, 484 insertions(+), 628 deletions(-) create mode 100644 test/shared_fixtures.py diff --git a/.gitignore b/.gitignore index 88819a6c..d7cc6fb5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,7 @@ sbin/** !opentofu/examples/manual-test.sh !src/entrypoint.sh !src/client/spring_ai/templates/env.sh -tests/db_startup_temp/** +test*/db_startup_temp ############################################################################## # Environment (PyVen, IDE, etc.) diff --git a/src/server/bootstrap/configfile.py b/src/server/bootstrap/configfile.py index 2dc3dbcb..e713078c 100644 --- a/src/server/bootstrap/configfile.py +++ b/src/server/bootstrap/configfile.py @@ -45,6 +45,12 @@ def get(cls): """Return the configuration stored in memory""" return cls._config + @classmethod + def reset(cls): + """Reset the configuration state. Used for testing.""" + with cls._lock: + cls._config = None + def config_file_path() -> str: """Return the path where settings should be stored.""" diff --git a/test/integration/server/api/conftest.py b/test/integration/server/api/conftest.py index 045d68e0..6b25297f 100644 --- a/test/integration/server/api/conftest.py +++ b/test/integration/server/api/conftest.py @@ -10,19 +10,24 @@ Note: db_container fixture is inherited from test/conftest.py - do not import here. 
""" -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name unused-import # Pytest fixtures use parameter injection where fixture names match parameters import os import asyncio from typing import Generator +# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) from test.db_fixtures import TEST_DB_CONFIG +from test.shared_fixtures import ( + make_database, + make_model, + DEFAULT_LL_MODEL_CONFIG, +) import pytest from fastapi.testclient import TestClient -from common.schema import Database, Model from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS @@ -119,48 +124,10 @@ def sample_settings_payload(): """Sample settings configuration for testing.""" return { "client": TEST_CONFIG["client"], - "ll_model": { - "model": "gpt-4o-mini", - "temperature": 0.7, - "max_tokens": 4096, - "chat_history": True, - }, + "ll_model": DEFAULT_LL_MODEL_CONFIG.copy(), } -################################################# -# Schema Factory Fixtures -################################################# -@pytest.fixture -def make_database(): - """Factory fixture for creating Database objects.""" - def _make_database(**kwargs): - defaults = { - "name": "TEST_DB", - "user": "test_user", - "password": "test_password", - "dsn": "localhost:1521/TEST", - } - defaults.update(kwargs) - return Database(**defaults) - return _make_database - - -@pytest.fixture -def make_model(): - """Factory fixture for creating Model objects.""" - def _make_model(**kwargs): - defaults = { - "id": "test-model", - "type": "ll", - "provider": "openai", - "enabled": True, - } - defaults.update(kwargs) - return Model(**defaults) - return _make_model - - ################################################# # State Management Helpers ################################################# diff --git a/test/integration/server/bootstrap/conftest.py b/test/integration/server/bootstrap/conftest.py index 4bea8405..00848e66 100644 --- a/test/integration/server/bootstrap/conftest.py +++ b/test/integration/server/bootstrap/conftest.py @@ -9,16 +9,24 @@ verify end-to-end behavior of the bootstrap system. """ -# pylint: disable=redefined-outer-name protected-access +# pylint: disable=redefined-outer-name unused-import import json -import os import tempfile from pathlib import Path +# Re-export shared fixtures for pytest discovery +from test.shared_fixtures import ( + reset_config_store, + clean_env, + BOOTSTRAP_ENV_VARS, + DEFAULT_LL_MODEL_CONFIG, +) + import pytest -from server.bootstrap.configfile import ConfigStore +# Alias for backwards compatibility +clean_bootstrap_env = clean_env @pytest.fixture @@ -94,68 +102,6 @@ def _make_oci_config_file( return _make_oci_config_file -@pytest.fixture -def clean_bootstrap_env(): - """Fixture to clean environment variables that affect bootstrap. - - This fixture saves current env vars, clears them for the test, - and restores them afterward. 
- """ - env_vars = [ - # Database vars - "DB_USERNAME", - "DB_PASSWORD", - "DB_DSN", - "DB_WALLET_PASSWORD", - "TNS_ADMIN", - # Model API keys - "OPENAI_API_KEY", - "COHERE_API_KEY", - "PPLX_API_KEY", - # On-prem model URLs - "ON_PREM_OLLAMA_URL", - "ON_PREM_VLLM_URL", - "ON_PREM_HF_URL", - # OCI vars - "OCI_CLI_CONFIG_FILE", - "OCI_CLI_TENANCY", - "OCI_CLI_REGION", - "OCI_CLI_USER", - "OCI_CLI_FINGERPRINT", - "OCI_CLI_KEY_FILE", - "OCI_CLI_SECURITY_TOKEN_FILE", - "OCI_CLI_AUTH", - "OCI_GENAI_COMPARTMENT_ID", - "OCI_GENAI_REGION", - "OCI_GENAI_SERVICE_ENDPOINT", - ] - - original_values = {} - for var in env_vars: - original_values[var] = os.environ.pop(var, None) - - yield - - # Restore original values - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] - - -@pytest.fixture -def reset_config_store(): - """Reset ConfigStore singleton state before and after each test.""" - # Reset before test - ConfigStore._config = None - - yield ConfigStore - - # Reset after test - ConfigStore._config = None - - @pytest.fixture def sample_database_config(): """Sample database configuration dict.""" @@ -197,10 +143,5 @@ def sample_settings_config(): """Sample settings configuration dict.""" return { "client": "integration_client", - "ll_model": { - "model": "gpt-4o-mini", - "temperature": 0.7, - "max_tokens": 4096, - "chat_history": True, - }, + "ll_model": DEFAULT_LL_MODEL_CONFIG.copy(), } diff --git a/test/integration/server/bootstrap/test_bootstrap_databases.py b/test/integration/server/bootstrap/test_bootstrap_databases.py index f46065b3..70336db5 100644 --- a/test/integration/server/bootstrap/test_bootstrap_databases.py +++ b/test/integration/server/bootstrap/test_bootstrap_databases.py @@ -12,10 +12,15 @@ import os +from test.shared_fixtures import ( + assert_database_list_valid, + assert_has_default_database, + get_database_by_name, +) + import pytest from server.bootstrap import databases as databases_module -from common.schema import Database @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") @@ -25,16 +30,12 @@ class TestDatabasesBootstrapWithConfig: def test_bootstrap_returns_database_objects(self): """databases.main() should return list of Database objects.""" result = databases_module.main() - - assert isinstance(result, list) - assert all(isinstance(db, Database) for db in result) + assert_database_list_valid(result) def test_bootstrap_creates_default_database(self): """databases.main() should always create DEFAULT database.""" result = databases_module.main() - - db_names = [db.name for db in result] - assert "DEFAULT" in db_names + assert_has_default_database(result) def test_bootstrap_with_config_file_databases(self, reset_config_store, make_config_file): """databases.main() should load databases from config file.""" @@ -81,12 +82,10 @@ def test_bootstrap_default_from_config_overridden_by_env(self, reset_config_stor try: reset_config_store.load_from_file(config_path) result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.user == "env_user" assert default_db.password == "env_password" - # DSN not in env, should keep config value - assert default_db.dsn == "config_host:1521/CFGPDB" + assert default_db.dsn == "config_host:1521/CFGPDB" # DSN not in env, keep config value finally: del os.environ["DB_USERNAME"] del os.environ["DB_PASSWORD"] @@ -118,8 +117,7 @@ def 
test_bootstrap_uses_env_vars_for_default(self): try: result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.user == "env_user" assert default_db.password == "env_password" assert default_db.dsn == "env_host:1521/ENVPDB" @@ -135,8 +133,7 @@ def test_bootstrap_wallet_password_sets_wallet_location(self): try: result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.wallet_password == "wallet_secret" assert default_db.wallet_location == "/path/to/wallet" assert default_db.config_dir == "/path/to/wallet" @@ -147,8 +144,7 @@ def test_bootstrap_wallet_password_sets_wallet_location(self): def test_bootstrap_tns_admin_default(self): """databases.main() should use 'tns_admin' as default config_dir.""" result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.config_dir == "tns_admin" @@ -174,8 +170,7 @@ def test_bootstrap_preserves_non_default_databases(self, reset_config_store, mak try: reset_config_store.load_from_file(config_path) result = databases_module.main() - - custom_db = next(db for db in result if db.name == "CUSTOM_DB") + custom_db = get_database_by_name(result, "CUSTOM_DB") assert custom_db.user == "custom_user" assert custom_db.password == "custom_pass" finally: @@ -194,12 +189,9 @@ def test_bootstrap_creates_default_when_not_in_config(self, reset_config_store, try: reset_config_store.load_from_file(config_path) result = databases_module.main() - - db_names = [db.name for db in result] - assert "DEFAULT" in db_names - assert "OTHER_DB" in db_names - - default_db = next(db for db in result if db.name == "DEFAULT") + assert_has_default_database(result) + assert "OTHER_DB" in [d.name for d in result] + default_db = get_database_by_name(result, "DEFAULT") assert default_db.user == "env_default_user" finally: del os.environ["DB_USERNAME"] diff --git a/test/integration/server/bootstrap/test_bootstrap_models.py b/test/integration/server/bootstrap/test_bootstrap_models.py index 52eb5ccc..b39042ec 100644 --- a/test/integration/server/bootstrap/test_bootstrap_models.py +++ b/test/integration/server/bootstrap/test_bootstrap_models.py @@ -13,10 +13,11 @@ import os from unittest.mock import patch +from test.shared_fixtures import assert_model_list_valid, get_model_by_id + import pytest from server.bootstrap import models as models_module -from common.schema import Model @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") @@ -28,9 +29,7 @@ def test_bootstrap_returns_model_objects(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - assert isinstance(result, list) - assert all(isinstance(m, Model) for m in result) + assert_model_list_valid(result) def test_bootstrap_includes_base_models(self): """models.main() should include base model configurations.""" @@ -66,8 +65,7 @@ def test_bootstrap_enables_models_with_openai_key(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - openai_model = next(m for m in result if m.id == "gpt-4o-mini") + openai_model = get_model_by_id(result, "gpt-4o-mini") assert openai_model.enabled is True 
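+            # The key value is copied verbatim from the OPENAI_API_KEY env var set above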
assert openai_model.api_key == "test-openai-key" finally: @@ -81,8 +79,7 @@ def test_bootstrap_enables_models_with_cohere_key(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - cohere_model = next(m for m in result if m.id == "command-r") + cohere_model = get_model_by_id(result, "command-r") assert cohere_model.enabled is True assert cohere_model.api_key == "test-cohere-key" finally: @@ -93,10 +90,8 @@ def test_bootstrap_disables_models_without_keys(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - # Without OPENAI_API_KEY, the model should be disabled - openai_model = next(m for m in result if m.id == "gpt-4o-mini") - assert openai_model.enabled is False + openai_model = get_model_by_id(result, "gpt-4o-mini") + assert openai_model.enabled is False # Without OPENAI_API_KEY @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") @@ -111,8 +106,7 @@ def test_bootstrap_enables_ollama_with_url(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - ollama_model = next(m for m in result if m.id == "llama3.1") + ollama_model = get_model_by_id(result, "llama3.1") assert ollama_model.enabled is True assert ollama_model.api_base == "http://localhost:11434" finally: @@ -126,10 +120,8 @@ def test_bootstrap_checks_url_accessibility(self): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (False, "Connection refused") result = models_module.main() - - ollama_model = next(m for m in result if m.id == "llama3.1") - # Should be disabled if URL is not accessible - assert ollama_model.enabled is False + ollama_model = get_model_by_id(result, "llama3.1") + assert ollama_model.enabled is False # Should be disabled if URL not accessible finally: del os.environ["ON_PREM_OLLAMA_URL"] @@ -163,7 +155,7 @@ def test_bootstrap_merges_config_store_models(self, reset_config_store, make_con model_ids = [m.id for m in result] assert "custom-model" in model_ids - custom_model = next(m for m in result if m.id == "custom-model") + custom_model = get_model_by_id(result, "custom-model") assert custom_model.provider == "custom" assert custom_model.api_base == "https://custom.api/v1" finally: @@ -191,8 +183,7 @@ def test_bootstrap_config_store_overrides_base_model(self, reset_config_store, m with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") result = models_module.main() - - openai_model = next(m for m in result if m.id == "gpt-4o-mini") + openai_model = get_model_by_id(result, "gpt-4o-mini") assert openai_model.api_key == "override-key" assert openai_model.max_tokens == 9999 finally: diff --git a/test/shared_fixtures.py b/test/shared_fixtures.py new file mode 100644 index 00000000..22953e39 --- /dev/null +++ b/test/shared_fixtures.py @@ -0,0 +1,330 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Shared pytest fixtures for unit and integration tests. + +This module contains common fixture factories and utilities that are shared +across multiple test conftest files to avoid code duplication. 
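+
+Typical conftest.py usage (pytest discovers fixtures by the imported name):
+
+    from test.shared_fixtures import make_database, reset_config_store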
+""" + +# pylint: disable=redefined-outer-name + +import json +import os +import tempfile +from pathlib import Path + +import pytest + +from common.schema import ( + Configuration, + Database, + Model, + OracleCloudSettings, + Settings, + LargeLanguageSettings, +) +from server.bootstrap.configfile import ConfigStore + + +# Default test model settings - shared across test fixtures +DEFAULT_LL_MODEL_CONFIG = { + "model": "gpt-4o-mini", + "temperature": 0.7, + "max_tokens": 4096, + "chat_history": True, +} + +# Environment variables used by bootstrap modules +BOOTSTRAP_ENV_VARS = [ + # Database vars + "DB_USERNAME", + "DB_PASSWORD", + "DB_DSN", + "DB_WALLET_PASSWORD", + "TNS_ADMIN", + # Model API keys + "OPENAI_API_KEY", + "COHERE_API_KEY", + "PPLX_API_KEY", + # On-prem model URLs + "ON_PREM_OLLAMA_URL", + "ON_PREM_VLLM_URL", + "ON_PREM_HF_URL", + # OCI vars + "OCI_CLI_CONFIG_FILE", + "OCI_CLI_TENANCY", + "OCI_CLI_REGION", + "OCI_CLI_USER", + "OCI_CLI_FINGERPRINT", + "OCI_CLI_KEY_FILE", + "OCI_CLI_SECURITY_TOKEN_FILE", + "OCI_CLI_AUTH", + "OCI_GENAI_COMPARTMENT_ID", + "OCI_GENAI_REGION", + "OCI_GENAI_SERVICE_ENDPOINT", +] + + +################################################# +# Schema Factory Fixtures +################################################# + + +@pytest.fixture +def make_database(): + """Factory fixture to create Database objects.""" + + def _make_database( + name: str = "TEST_DB", + user: str = "test_user", + password: str = "test_password", + dsn: str = "localhost:1521/TESTPDB", + wallet_password: str = None, + **kwargs, + ) -> Database: + return Database( + name=name, + user=user, + password=password, + dsn=dsn, + wallet_password=wallet_password, + **kwargs, + ) + + return _make_database + + +@pytest.fixture +def make_model(): + """Factory fixture to create Model objects. + + Supports both `model_id` and `id` parameter names for backwards compatibility. + """ + + def _make_model( + model_id: str = None, + model_type: str = "ll", + provider: str = "openai", + enabled: bool = True, + api_key: str = "test-key", + api_base: str = "https://api.openai.com/v1", + **kwargs, + ) -> Model: + # Support both 'id' kwarg and 'model_id' parameter for backwards compat + resolved_id = kwargs.pop("id", None) or model_id or "gpt-4o-mini" + return Model( + id=resolved_id, + type=model_type, + provider=provider, + enabled=enabled, + api_key=api_key, + api_base=api_base, + **kwargs, + ) + + return _make_model + + +@pytest.fixture +def make_oci_config(): + """Factory fixture to create OracleCloudSettings objects. + + Note: The 'user' field requires OCID format pattern matching. + Use None to skip the user field in tests that don't need it. 
+ """ + + def _make_oci_config( + auth_profile: str = "DEFAULT", + tenancy: str = "test-tenancy", + region: str = "us-ashburn-1", + user: str = None, # Use None by default - OCID pattern required + fingerprint: str = "test-fingerprint", + key_file: str = "/path/to/key", + **kwargs, + ) -> OracleCloudSettings: + return OracleCloudSettings( + auth_profile=auth_profile, + tenancy=tenancy, + region=region, + user=user, + fingerprint=fingerprint, + key_file=key_file, + **kwargs, + ) + + return _make_oci_config + + +@pytest.fixture +def make_ll_settings(): + """Factory fixture to create LargeLanguageSettings objects.""" + + def _make_ll_settings( + model: str = "gpt-4o-mini", + temperature: float = 0.7, + max_tokens: int = 4096, + chat_history: bool = True, + **kwargs, + ) -> LargeLanguageSettings: + return LargeLanguageSettings( + model=model, + temperature=temperature, + max_tokens=max_tokens, + chat_history=chat_history, + **kwargs, + ) + + return _make_ll_settings + + +@pytest.fixture +def make_settings(make_ll_settings): + """Factory fixture to create Settings objects.""" + + def _make_settings( + client: str = "test_client", + ll_model: LargeLanguageSettings = None, + **kwargs, + ) -> Settings: + if ll_model is None: + ll_model = make_ll_settings() + return Settings( + client=client, + ll_model=ll_model, + **kwargs, + ) + + return _make_settings + + +@pytest.fixture +def make_configuration(make_settings): + """Factory fixture to create Configuration objects.""" + + def _make_configuration( + client_settings: Settings = None, + database_configs: list = None, + model_configs: list = None, + oci_configs: list = None, + **kwargs, + ) -> Configuration: + return Configuration( + client_settings=client_settings or make_settings(), + database_configs=database_configs or [], + model_configs=model_configs or [], + oci_configs=oci_configs or [], + prompt_configs=[], + **kwargs, + ) + + return _make_configuration + + +################################################# +# Config File Fixtures +################################################# + + +@pytest.fixture +def temp_config_file(make_settings): + """Create a temporary configuration JSON file.""" + + def _create_temp_config( + client_settings: Settings = None, + database_configs: list = None, + model_configs: list = None, + oci_configs: list = None, + ): + config_data = { + "client_settings": (client_settings or make_settings()).model_dump(), + "database_configs": [ + (db if isinstance(db, dict) else db.model_dump()) + for db in (database_configs or []) + ], + "model_configs": [ + (m if isinstance(m, dict) else m.model_dump()) + for m in (model_configs or []) + ], + "oci_configs": [ + (o if isinstance(o, dict) else o.model_dump()) + for o in (oci_configs or []) + ], + "prompt_configs": [], + } + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False, encoding="utf-8" + ) as temp_file: + json.dump(config_data, temp_file) + return Path(temp_file.name) + + return _create_temp_config + + +@pytest.fixture +def reset_config_store(): + """Reset ConfigStore singleton state before and after each test.""" + # Reset before test + ConfigStore.reset() + + yield ConfigStore + + # Reset after test + ConfigStore.reset() + + +################################################# +# Test Helper Functions (shared assertions to reduce duplication) +################################################# + + +def assert_database_list_valid(result): + """Assert that result is a valid list of Database objects.""" + assert isinstance(result, list) + assert 
all(isinstance(db, Database) for db in result) + + +def assert_has_default_database(result): + """Assert that DEFAULT database is in the result.""" + db_names = [db.name for db in result] + assert "DEFAULT" in db_names + + +def get_database_by_name(result, name): + """Get a database from results by name.""" + return next(db for db in result if db.name == name) + + +def assert_model_list_valid(result): + """Assert that result is a valid list of Model objects.""" + assert isinstance(result, list) + assert all(isinstance(m, Model) for m in result) + + +def get_model_by_id(result, model_id): + """Get a model from results by id.""" + return next(m for m in result if m.id == model_id) + + +################################################# +# Environment Fixtures +################################################# + + +@pytest.fixture +def clean_env(): + """Fixture to temporarily clear relevant environment variables.""" + original_values = {} + for var in BOOTSTRAP_ENV_VARS: + original_values[var] = os.environ.pop(var, None) + + yield + + # Restore original values + for var, value in original_values.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] diff --git a/test/unit/server/api/conftest.py b/test/unit/server/api/conftest.py index c1ba1493..8ec4542a 100644 --- a/test/unit/server/api/conftest.py +++ b/test/unit/server/api/conftest.py @@ -6,149 +6,43 @@ Provides factory fixtures for creating test objects. """ -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name unused-import # Pytest fixtures use parameter injection where fixture names match parameters from unittest.mock import MagicMock, AsyncMock + +# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) +from test.shared_fixtures import ( + make_database, + make_model, + make_oci_config, + make_ll_settings, + make_settings, + make_configuration, +) + import pytest from common.schema import ( - Database, DatabaseAuth, - Model, - OracleCloudSettings, - Settings, - LargeLanguageSettings, DatabaseVectorStorage, ChatRequest, - Configuration, ) -@pytest.fixture -def make_database(): - """Factory fixture to create Database objects.""" - - def _make_database( - name: str = "TEST_DB", - user: str = "test_user", - password: str = "test_password", - dsn: str = "localhost:1521/TESTPDB", - wallet_password: str = None, - **kwargs, - ) -> Database: - return Database( - name=name, - user=user, - password=password, - dsn=dsn, - wallet_password=wallet_password, - **kwargs, - ) - - return _make_database - - -@pytest.fixture -def make_model(): - """Factory fixture to create Model objects.""" - - def _make_model( - model_id: str = "gpt-4o-mini", - model_type: str = "ll", - provider: str = "openai", - enabled: bool = True, - **kwargs, - ) -> Model: - return Model( - id=model_id, - type=model_type, - provider=provider, - enabled=enabled, - **kwargs, - ) - - return _make_model - - -@pytest.fixture -def make_oci_config(): - """Factory fixture to create OracleCloudSettings objects.""" - - def _make_oci_config( - auth_profile: str = "DEFAULT", - genai_region: str = "us-ashburn-1", - **kwargs, - ) -> OracleCloudSettings: - return OracleCloudSettings( - auth_profile=auth_profile, - genai_region=genai_region, - **kwargs, - ) - - return _make_oci_config - - -@pytest.fixture -def make_ll_settings(): - """Factory fixture to create LargeLanguageSettings objects.""" - - def _make_ll_settings( - model: str = "gpt-4o-mini", - temperature: float = 0.7, - max_tokens: 
int = 4096, - chat_history: bool = True, - **kwargs, - ) -> LargeLanguageSettings: - return LargeLanguageSettings( - model=model, - temperature=temperature, - max_tokens=max_tokens, - chat_history=chat_history, - **kwargs, - ) - - return _make_ll_settings - - -@pytest.fixture -def make_settings(make_ll_settings): - """Factory fixture to create Settings objects.""" - - def _make_settings( - client: str = "test_client", - ll_model: LargeLanguageSettings = None, - **kwargs, - ) -> Settings: - if ll_model is None: - ll_model = make_ll_settings() - return Settings( - client=client, - ll_model=ll_model, - **kwargs, - ) - - return _make_settings - - @pytest.fixture def make_database_auth(): """Factory fixture to create DatabaseAuth objects.""" - def _make_database_auth( - user: str = "test_user", - password: str = "test_password", - dsn: str = "localhost:1521/TESTPDB", - wallet_password: str = None, - **kwargs, - ) -> DatabaseAuth: - return DatabaseAuth( - user=user, - password=password, - dsn=dsn, - wallet_password=wallet_password, - **kwargs, - ) + def _make_database_auth(**overrides) -> DatabaseAuth: + defaults = { + "user": "test_user", + "password": "test_password", + "dsn": "localhost:1521/TESTPDB", + "wallet_password": None, + } + defaults.update(overrides) + return DatabaseAuth(**defaults) return _make_database_auth @@ -215,29 +109,6 @@ def _make_mcp_prompt( return _make_mcp_prompt -@pytest.fixture -def make_configuration(make_settings): - """Factory fixture to create Configuration objects.""" - - def _make_configuration( - client: str = "test_client", - client_settings: Settings = None, - **kwargs, - ) -> Configuration: - if client_settings is None: - client_settings = make_settings(client=client) - return Configuration( - client_settings=client_settings, - database_configs=[], - model_configs=[], - oci_configs=[], - prompt_configs=[], - **kwargs, - ) - - return _make_configuration - - @pytest.fixture def mock_fastmcp(): """Create a mock FastMCP application.""" diff --git a/test/unit/server/bootstrap/conftest.py b/test/unit/server/bootstrap/conftest.py index aff230bb..9bc23743 100644 --- a/test/unit/server/bootstrap/conftest.py +++ b/test/unit/server/bootstrap/conftest.py @@ -3,220 +3,33 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. Pytest fixtures for server/bootstrap unit tests. + +Re-exports shared fixtures from test.shared_fixtures and adds unit-test specific fixtures. 
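+The mock fixtures defined below (e.g. mock_is_url_accessible) exist only for these unit tests.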
""" -# pylint: disable=redefined-outer-name protected-access too-few-public-methods +# pylint: disable=redefined-outer-name unused-import -import json -import os -import tempfile -from pathlib import Path from unittest.mock import MagicMock, patch -import pytest - -from common.schema import ( - Configuration, - Database, - Model, - OracleCloudSettings, - Settings, - LargeLanguageSettings, +# Re-export shared fixtures for pytest discovery +from test.shared_fixtures import ( + make_database, + make_model, + make_oci_config, + make_ll_settings, + make_settings, + make_configuration, + temp_config_file, + reset_config_store, + clean_env, ) -from server.bootstrap.configfile import ConfigStore - - -@pytest.fixture -def make_database(): - """Factory fixture to create Database objects.""" - - def _make_database( - name: str = "TEST_DB", - user: str = "test_user", - password: str = "test_password", - dsn: str = "localhost:1521/TESTPDB", - wallet_password: str = None, - **kwargs, - ) -> Database: - return Database( - name=name, - user=user, - password=password, - dsn=dsn, - wallet_password=wallet_password, - **kwargs, - ) - - return _make_database - - -@pytest.fixture -def make_model(): - """Factory fixture to create Model objects.""" - - def _make_model( - model_id: str = "gpt-4o-mini", - model_type: str = "ll", - provider: str = "openai", - enabled: bool = True, - api_key: str = "test-key", - api_base: str = "https://api.openai.com/v1", - **kwargs, - ) -> Model: - return Model( - id=model_id, - type=model_type, - provider=provider, - enabled=enabled, - api_key=api_key, - api_base=api_base, - **kwargs, - ) - - return _make_model - - -@pytest.fixture -def make_oci_config(): - """Factory fixture to create OracleCloudSettings objects. - - Note: The 'user' field requires OCID format pattern matching. - Use None to skip the user field in tests that don't need it. 
- """ - - def _make_oci_config( - auth_profile: str = "DEFAULT", - tenancy: str = "test-tenancy", - region: str = "us-ashburn-1", - user: str = None, # Use None by default - OCID pattern required - fingerprint: str = "test-fingerprint", - key_file: str = "/path/to/key", - **kwargs, - ) -> OracleCloudSettings: - return OracleCloudSettings( - auth_profile=auth_profile, - tenancy=tenancy, - region=region, - user=user, - fingerprint=fingerprint, - key_file=key_file, - **kwargs, - ) - - return _make_oci_config - - -@pytest.fixture -def make_ll_settings(): - """Factory fixture to create LargeLanguageSettings objects.""" - - def _make_ll_settings( - model: str = "gpt-4o-mini", - temperature: float = 0.7, - max_tokens: int = 4096, - chat_history: bool = True, - **kwargs, - ) -> LargeLanguageSettings: - return LargeLanguageSettings( - model=model, - temperature=temperature, - max_tokens=max_tokens, - chat_history=chat_history, - **kwargs, - ) - - return _make_ll_settings - -@pytest.fixture -def make_settings(make_ll_settings): - """Factory fixture to create Settings objects.""" - - def _make_settings( - client: str = "test_client", - ll_model: LargeLanguageSettings = None, - **kwargs, - ) -> Settings: - if ll_model is None: - ll_model = make_ll_settings() - return Settings( - client=client, - ll_model=ll_model, - **kwargs, - ) - - return _make_settings - - -@pytest.fixture -def make_configuration(make_settings): - """Factory fixture to create Configuration objects.""" - - def _make_configuration( - client_settings: Settings = None, - database_configs: list = None, - model_configs: list = None, - oci_configs: list = None, - **kwargs, - ) -> Configuration: - return Configuration( - client_settings=client_settings or make_settings(), - database_configs=database_configs or [], - model_configs=model_configs or [], - oci_configs=oci_configs or [], - prompt_configs=[], - **kwargs, - ) - - return _make_configuration - - -@pytest.fixture -def temp_config_file(make_settings): - """Create a temporary configuration JSON file.""" - - def _create_temp_config( - client_settings: Settings = None, - database_configs: list = None, - model_configs: list = None, - oci_configs: list = None, - ): - config_data = { - "client_settings": (client_settings or make_settings()).model_dump(), - "database_configs": [ - (db if isinstance(db, dict) else db.model_dump()) - for db in (database_configs or []) - ], - "model_configs": [ - (m if isinstance(m, dict) else m.model_dump()) - for m in (model_configs or []) - ], - "oci_configs": [ - (o if isinstance(o, dict) else o.model_dump()) - for o in (oci_configs or []) - ], - "prompt_configs": [], - } - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False, encoding="utf-8" - ) as temp_file: - json.dump(config_data, temp_file) - return Path(temp_file.name) - - return _create_temp_config - - -@pytest.fixture -def reset_config_store(): - """Reset ConfigStore singleton state before and after each test.""" - # Reset before test - ConfigStore._config = None +import pytest - yield ConfigStore - # Reset after test - ConfigStore._config = None +################################################# +# Unit Test Specific Mock Fixtures +################################################# @pytest.fixture @@ -242,45 +55,3 @@ def mock_is_url_accessible(): with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") yield mock_accessible - - -@pytest.fixture -def clean_env(): - """Fixture to temporarily clear relevant 
environment variables.""" - env_vars = [ - "DB_USERNAME", - "DB_PASSWORD", - "DB_DSN", - "DB_WALLET_PASSWORD", - "TNS_ADMIN", - "OPENAI_API_KEY", - "COHERE_API_KEY", - "PPLX_API_KEY", - "ON_PREM_OLLAMA_URL", - "ON_PREM_VLLM_URL", - "ON_PREM_HF_URL", - "OCI_CLI_CONFIG_FILE", - "OCI_CLI_TENANCY", - "OCI_CLI_REGION", - "OCI_CLI_USER", - "OCI_CLI_FINGERPRINT", - "OCI_CLI_KEY_FILE", - "OCI_CLI_SECURITY_TOKEN_FILE", - "OCI_CLI_AUTH", - "OCI_GENAI_COMPARTMENT_ID", - "OCI_GENAI_REGION", - "OCI_GENAI_SERVICE_ENDPOINT", - ] - - original_values = {} - for var in env_vars: - original_values[var] = os.environ.pop(var, None) - - yield - - # Restore original values - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] diff --git a/test/unit/server/bootstrap/test_bootstrap_databases.py b/test/unit/server/bootstrap/test_bootstrap_databases.py index 5ab5afb6..3a658689 100644 --- a/test/unit/server/bootstrap/test_bootstrap_databases.py +++ b/test/unit/server/bootstrap/test_bootstrap_databases.py @@ -10,10 +10,15 @@ import os +from test.shared_fixtures import ( + assert_database_list_valid, + assert_has_default_database, + get_database_by_name, +) + import pytest from server.bootstrap import databases as databases_module -from common.schema import Database @pytest.mark.usefixtures("reset_config_store", "clean_env") @@ -23,16 +28,12 @@ class TestDatabasesMain: def test_main_returns_list_of_databases(self): """main() should return a list of Database objects.""" result = databases_module.main() - - assert isinstance(result, list) - assert all(isinstance(db, Database) for db in result) + assert_database_list_valid(result) def test_main_creates_default_database_when_no_config(self): """main() should create DEFAULT database when no config is loaded.""" result = databases_module.main() - - db_names = [db.name for db in result] - assert "DEFAULT" in db_names + assert_has_default_database(result) def test_main_uses_env_vars_for_default_database(self): """main() should use environment variables for DEFAULT database.""" @@ -42,13 +43,12 @@ def test_main_uses_env_vars_for_default_database(self): os.environ["TNS_ADMIN"] = "/env/tns_admin" try: - result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") - assert default_db.user == "env_user" - assert default_db.password == "env_password" - assert default_db.dsn == "env_dsn:1521/ENVPDB" - assert default_db.config_dir == "/env/tns_admin" + db_list = databases_module.main() + default_entry = get_database_by_name(db_list, "DEFAULT") + assert default_entry.user == "env_user" + assert default_entry.password == "env_password" + assert default_entry.dsn == "env_dsn:1521/ENVPDB" + assert default_entry.config_dir == "/env/tns_admin" finally: del os.environ["DB_USERNAME"] del os.environ["DB_PASSWORD"] @@ -62,8 +62,7 @@ def test_main_sets_wallet_location_when_wallet_password_present(self): try: result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.wallet_password == "wallet_pass" assert default_db.wallet_location == "/wallet/path" finally: @@ -81,9 +80,9 @@ def test_main_with_config_file_databases( try: reset_config_store.load_from_file(config_path) - result = databases_module.main() + integration_result = databases_module.main() - db_names = [db.name for db in result] + db_names = [db.name for db in integration_result] assert "CONFIG_DB1" 
in db_names assert "CONFIG_DB2" in db_names finally: @@ -93,24 +92,22 @@ def test_main_overrides_default_from_config_with_env_vars( self, reset_config_store, temp_config_file, make_settings, make_database ): """main() should override DEFAULT database from config with env vars.""" - settings = make_settings() - db = make_database(name="DEFAULT", user="config_user", password="config_pass", dsn="config_dsn") - config_path = temp_config_file(client_settings=settings, database_configs=[db]) + test_settings = make_settings() + test_db = make_database(name="DEFAULT", user="config_user", password="config_pass", dsn="config_dsn") + cfg_path = temp_config_file(client_settings=test_settings, database_configs=[test_db]) os.environ["DB_USERNAME"] = "env_user" os.environ["DB_PASSWORD"] = "env_password" try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") - assert default_db.user == "env_user" - assert default_db.password == "env_password" - # DSN not in env, should keep config value - assert default_db.dsn == "config_dsn" + reset_config_store.load_from_file(cfg_path) + db_list = databases_module.main() + default_entry = get_database_by_name(db_list, "DEFAULT") + assert default_entry.user == "env_user" + assert default_entry.password == "env_password" + assert default_entry.dsn == "config_dsn" # DSN not in env, keep config value finally: - os.unlink(config_path) + os.unlink(cfg_path) del os.environ["DB_USERNAME"] del os.environ["DB_PASSWORD"] @@ -135,24 +132,21 @@ def test_main_creates_default_when_not_in_config( self, reset_config_store, temp_config_file, make_settings, make_database ): """main() should create DEFAULT database from env when not in config.""" - settings = make_settings() - db = make_database(name="OTHER_DB", user="other_user") - config_path = temp_config_file(client_settings=settings, database_configs=[db]) + test_settings = make_settings() + other_db = make_database(name="OTHER_DB", user="other_user") + cfg_path = temp_config_file(client_settings=test_settings, database_configs=[other_db]) os.environ["DB_USERNAME"] = "default_env_user" try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - - db_names = [db.name for db in result] - assert "DEFAULT" in db_names - assert "OTHER_DB" in db_names - - default_db = next(db for db in result if db.name == "DEFAULT") - assert default_db.user == "default_env_user" + reset_config_store.load_from_file(cfg_path) + db_list = databases_module.main() + assert_has_default_database(db_list) + assert "OTHER_DB" in [d.name for d in db_list] + default_entry = get_database_by_name(db_list, "DEFAULT") + assert default_entry.user == "default_env_user" finally: - os.unlink(config_path) + os.unlink(cfg_path) del os.environ["DB_USERNAME"] def test_main_handles_case_insensitive_default_name( @@ -180,28 +174,26 @@ def test_main_preserves_non_default_databases_unchanged( self, reset_config_store, temp_config_file, make_settings, make_database ): """main() should not modify non-DEFAULT databases.""" - settings = make_settings() - db = make_database(name="CUSTOM_DB", user="custom_user", password="custom_pass") - config_path = temp_config_file(client_settings=settings, database_configs=[db]) + test_settings = make_settings() + custom_db_config = make_database(name="CUSTOM_DB", user="custom_user", password="custom_pass") + cfg_path = temp_config_file(client_settings=test_settings, database_configs=[custom_db_config]) 
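+        # This override must only ever be applied to the DEFAULT entry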
os.environ["DB_USERNAME"] = "should_not_apply" try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - - custom_db = next(db for db in result if db.name == "CUSTOM_DB") - assert custom_db.user == "custom_user" - assert custom_db.password == "custom_pass" + reset_config_store.load_from_file(cfg_path) + db_list = databases_module.main() + custom_entry = get_database_by_name(db_list, "CUSTOM_DB") + assert custom_entry.user == "custom_user" + assert custom_entry.password == "custom_pass" finally: - os.unlink(config_path) + os.unlink(cfg_path) del os.environ["DB_USERNAME"] def test_main_default_config_dir_fallback(self): """main() should use 'tns_admin' as default config_dir when not specified.""" result = databases_module.main() - - default_db = next(db for db in result if db.name == "DEFAULT") + default_db = get_database_by_name(result, "DEFAULT") assert default_db.config_dir == "tns_admin" diff --git a/test/unit/server/bootstrap/test_bootstrap_models.py b/test/unit/server/bootstrap/test_bootstrap_models.py index 27728c5d..8f8a8b1d 100644 --- a/test/unit/server/bootstrap/test_bootstrap_models.py +++ b/test/unit/server/bootstrap/test_bootstrap_models.py @@ -11,10 +11,11 @@ import os from unittest.mock import patch +from test.shared_fixtures import assert_model_list_valid, get_model_by_id + import pytest from server.bootstrap import models as models_module -from common.schema import Model @pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible") @@ -24,9 +25,7 @@ class TestModelsMain: def test_main_returns_list_of_models(self): """main() should return a list of Model objects.""" result = models_module.main() - - assert isinstance(result, list) - assert all(isinstance(m, Model) for m in result) + assert_model_list_valid(result) def test_main_includes_base_models(self): """main() should include base model configurations.""" @@ -42,20 +41,18 @@ def test_main_enables_models_with_api_keys(self): os.environ["OPENAI_API_KEY"] = "test-openai-key" try: - result = models_module.main() - - openai_model = next(m for m in result if m.id == "gpt-4o-mini") - assert openai_model.enabled is True - assert openai_model.api_key == "test-openai-key" + model_list = models_module.main() + gpt_model = get_model_by_id(model_list, "gpt-4o-mini") + assert gpt_model.enabled is True + assert gpt_model.api_key == "test-openai-key" finally: del os.environ["OPENAI_API_KEY"] def test_main_disables_models_without_api_keys(self): """main() should disable models when API keys are not present.""" - result = models_module.main() - - openai_model = next(m for m in result if m.id == "gpt-4o-mini") - assert openai_model.enabled is False + model_list = models_module.main() + gpt_model = get_model_by_id(model_list, "gpt-4o-mini") + assert gpt_model.enabled is False @pytest.mark.usefixtures("reset_config_store", "clean_env") def test_main_checks_url_accessibility(self): @@ -67,10 +64,8 @@ def test_main_checks_url_accessibility(self): try: result = models_module.main() - - # Model should be disabled if URL is not accessible - openai_model = next(m for m in result if m.id == "gpt-4o-mini") - assert openai_model.enabled is False + openai_model = get_model_by_id(result, "gpt-4o-mini") + assert openai_model.enabled is False # Model disabled if URL not accessible mock_accessible.assert_called() finally: del os.environ["OPENAI_API_KEY"] From c40eb10b92060ac0fc284eab9dc0898043d62d8c Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 19:04:29 +0000 Subject: [PATCH 10/20] 
Add tests for latest merge

---
 .../api/utils/test_utils_testbed_metrics.py   | 345 ++++++++++++++++++
 test/unit/server/api/v1/test_v1_testbed.py    | 342 ++++++++++++++++-
 2 files changed, 675 insertions(+), 12 deletions(-)
 create mode 100644 test/unit/server/api/utils/test_utils_testbed_metrics.py

diff --git a/test/unit/server/api/utils/test_utils_testbed_metrics.py b/test/unit/server/api/utils/test_utils_testbed_metrics.py
new file mode 100644
index 00000000..4431f4e5
--- /dev/null
+++ b/test/unit/server/api/utils/test_utils_testbed_metrics.py
@@ -0,0 +1,345 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Unit tests for server/api/utils/testbed_metrics.py
+Tests for custom testbed evaluation metrics.
+"""
+
+# pylint: disable=too-few-public-methods,protected-access
+
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+from giskard.llm.errors import LLMGenerationError
+
+from server.api.utils import testbed_metrics
+
+
+class TestFormatConversation:
+    """Tests for the format_conversation function."""
+
+    def test_format_conversation_single_message(self):
+        """Should format single message correctly."""
+        conversation = [{"role": "user", "content": "Hello"}]
+
+        result = testbed_metrics.format_conversation(conversation)
+
+        assert result == "<user>Hello</user>"
+
+    def test_format_conversation_multiple_messages(self):
+        """Should format multiple messages with double newlines."""
+        conversation = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there"},
+        ]
+
+        result = testbed_metrics.format_conversation(conversation)
+
+        assert "Hello" in result
+        assert "Hi there" in result
+        assert "\n\n" in result
+
+    def test_format_conversation_lowercases_role(self):
+        """Should lowercase role names in tags."""
+        conversation = [{"role": "USER", "content": "Test"}]
+
+        result = testbed_metrics.format_conversation(conversation)
+
+        assert result == "<user>Test</user>"
+
+    def test_format_conversation_empty_list(self):
+        """Should return empty string for empty conversation."""
+        result = testbed_metrics.format_conversation([])
+
+        assert result == ""
+
+    def test_format_conversation_preserves_content(self):
+        """Should preserve message content including special characters."""
+        conversation = [{"role": "user", "content": "What is 2 + 2?\nIs it 4?"}]
+
+        result = testbed_metrics.format_conversation(conversation)
+
+        assert "What is 2 + 2?\nIs it 4?" 
in result + + +class TestCorrectnessInputTemplate: + """Tests for the CORRECTNESS_INPUT_TEMPLATE constant.""" + + def test_template_contains_placeholders(self): + """Template should contain all required placeholders.""" + template = testbed_metrics.CORRECTNESS_INPUT_TEMPLATE + + assert "{description}" in template + assert "{conversation}" in template + assert "{answer}" in template + assert "{reference_answer}" in template + + def test_template_format_works(self): + """Template should be formattable with all placeholders.""" + result = testbed_metrics.CORRECTNESS_INPUT_TEMPLATE.format( + description="Test agent", + conversation="Hello", + answer="Hi there", + reference_answer="Hello back", + ) + + assert "Test agent" in result + assert "Hello" in result + assert "Hi there" in result + assert "Hello back" in result + + +class TestCustomCorrectnessMetricInit: + """Tests for CustomCorrectnessMetric initialization.""" + + def test_init_with_required_params(self): + """Should initialize with required parameters.""" + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + assert metric.system_prompt == "You are a judge." + assert metric.agent_description == "A chatbot answering questions." + + def test_init_with_custom_agent_description(self): + """Should accept custom agent description.""" + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + agent_description="A specialized Q&A bot.", + ) + + assert metric.agent_description == "A specialized Q&A bot." + + def test_init_with_llm_client(self): + """Should accept custom LLM client.""" + mock_client = MagicMock() + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + llm_client=mock_client, + ) + + assert metric._llm_client == mock_client + + +class TestCustomCorrectnessMetricCall: + """Tests for CustomCorrectnessMetric __call__ method.""" + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_returns_correctness_result(self, mock_parse, mock_get_client): + """Should return correctness evaluation result.""" + mock_client = MagicMock() + mock_client.complete.return_value = MagicMock(content='{"correctness": true}') + mock_get_client.return_value = mock_client + mock_parse.return_value = {"correctness": True} + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "What is AI?" 
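+        # reference_answer is the ground truth the judge compares the agent's reply against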
+ mock_sample.reference_answer = "Artificial Intelligence" + + mock_answer = MagicMock() + mock_answer.message = "AI stands for Artificial Intelligence" + + result = metric(mock_sample, mock_answer) + + assert result == {"correctness": True} + mock_client.complete.assert_called_once() + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_strips_reason_when_correct(self, mock_parse, mock_get_client): + """Should strip correctness_reason when answer is correct.""" + mock_client = MagicMock() + mock_client.complete.return_value = MagicMock(content='{}') + mock_get_client.return_value = mock_client + mock_parse.return_value = {"correctness": True, "correctness_reason": "Matches exactly"} + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "A" + + result = metric(mock_sample, mock_answer) + + assert "correctness_reason" not in result + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_keeps_reason_when_incorrect(self, mock_parse, mock_get_client): + """Should keep correctness_reason when answer is incorrect.""" + mock_client = MagicMock() + mock_client.complete.return_value = MagicMock(content='{}') + mock_get_client.return_value = mock_client + mock_parse.return_value = {"correctness": False, "correctness_reason": "Does not match"} + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "Wrong" + + result = metric(mock_sample, mock_answer) + + assert result["correctness_reason"] == "Does not match" + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_raises_on_non_boolean_correctness(self, mock_parse, mock_get_client): + """Should raise LLMGenerationError if correctness is not boolean.""" + mock_client = MagicMock() + mock_client.complete.return_value = MagicMock(content='{}') + mock_get_client.return_value = mock_client + mock_parse.return_value = {"correctness": "yes"} # String instead of bool + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "A" + + with pytest.raises(LLMGenerationError) as exc_info: + metric(mock_sample, mock_answer) + + assert "Expected boolean" in str(exc_info.value) + + @patch("server.api.utils.testbed_metrics.get_default_client") + def test_call_reraises_llm_generation_error(self, mock_get_client): + """Should re-raise LLMGenerationError from LLM client.""" + mock_client = MagicMock() + mock_client.complete.side_effect = LLMGenerationError("LLM failed") + mock_get_client.return_value = mock_client + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + 
mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "A" + + with pytest.raises(LLMGenerationError): + metric(mock_sample, mock_answer) + + @patch("server.api.utils.testbed_metrics.get_default_client") + def test_call_wraps_other_exceptions(self, mock_get_client): + """Should wrap other exceptions in LLMGenerationError.""" + mock_client = MagicMock() + mock_client.complete.side_effect = RuntimeError("Unexpected error") + mock_get_client.return_value = mock_client + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "A" + + with pytest.raises(LLMGenerationError) as exc_info: + metric(mock_sample, mock_answer) + + assert "Error while evaluating" in str(exc_info.value) + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_uses_provided_llm_client(self, mock_parse, mock_get_client): + """Should use provided LLM client instead of default.""" + mock_provided_client = MagicMock() + mock_provided_client.complete.return_value = MagicMock(content='{}') + mock_parse.return_value = {"correctness": True} + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + llm_client=mock_provided_client, + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [] + mock_sample.question = "Q" + mock_sample.reference_answer = "A" + + mock_answer = MagicMock() + mock_answer.message = "A" + + metric(mock_sample, mock_answer) + + mock_provided_client.complete.assert_called_once() + mock_get_client.assert_not_called() + + @patch("server.api.utils.testbed_metrics.get_default_client") + @patch("server.api.utils.testbed_metrics.parse_json_output") + def test_call_includes_conversation_history(self, mock_parse, mock_get_client): + """Should include conversation history in the prompt.""" + mock_client = MagicMock() + mock_client.complete.return_value = MagicMock(content='{}') + mock_get_client.return_value = mock_client + mock_parse.return_value = {"correctness": True} + + metric = testbed_metrics.CustomCorrectnessMetric( + name="correctness", + system_prompt="You are a judge.", + ) + + mock_sample = MagicMock() + mock_sample.conversation_history = [ + {"role": "user", "content": "Previous question"}, + {"role": "assistant", "content": "Previous answer"}, + ] + mock_sample.question = "Follow-up question" + mock_sample.reference_answer = "Expected answer" + + mock_answer = MagicMock() + mock_answer.message = "Actual answer" + + metric(mock_sample, mock_answer) + + call_args = mock_client.complete.call_args + user_message = call_args.kwargs["messages"][1].content + assert "Previous question" in user_message + assert "Previous answer" in user_message + assert "Follow-up question" in user_message diff --git a/test/unit/server/api/v1/test_v1_testbed.py b/test/unit/server/api/v1/test_v1_testbed.py index 14edd3dd..8ba6d12c 100644 --- a/test/unit/server/api/v1/test_v1_testbed.py +++ b/test/unit/server/api/v1/test_v1_testbed.py @@ -5,9 +5,10 @@ Unit tests for server/api/v1/testbed.py Tests for Q&A testbed and evaluation endpoints. 
""" -# pylint: disable=protected-access,too-few-public-methods +# pylint: disable=protected-access,too-few-public-methods,too-many-arguments +# pylint: disable=too-many-positional-arguments,too-many-locals -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock from io import BytesIO import pytest from fastapi import HTTPException, UploadFile @@ -23,7 +24,9 @@ class TestTestbedTestsets: @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.get_testsets") - async def test_testbed_testsets_returns_list(self, mock_get_testsets, mock_get_db, mock_db_connection): + async def test_testbed_testsets_returns_list( + self, mock_get_testsets, mock_get_db, mock_db_connection + ): """testbed_testsets should return list of testsets.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -43,7 +46,9 @@ async def test_testbed_testsets_returns_list(self, mock_get_testsets, mock_get_d @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.get_testsets") - async def test_testbed_testsets_empty_list(self, mock_get_testsets, mock_get_db, mock_db_connection): + async def test_testbed_testsets_empty_list( + self, mock_get_testsets, mock_get_db, mock_db_connection + ): """testbed_testsets should return empty list when no testsets.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -61,7 +66,9 @@ class TestTestbedEvaluations: @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.get_evaluations") - async def test_testbed_evaluations_returns_list(self, mock_get_evals, mock_get_db, mock_db_connection): + async def test_testbed_evaluations_returns_list( + self, mock_get_evals, mock_get_db, mock_db_connection + ): """testbed_evaluations should return list of evaluations.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -81,7 +88,9 @@ async def test_testbed_evaluations_returns_list(self, mock_get_evals, mock_get_d @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.get_evaluations") - async def test_testbed_evaluations_uppercases_tid(self, mock_get_evals, mock_get_db, mock_db_connection): + async def test_testbed_evaluations_uppercases_tid( + self, mock_get_evals, mock_get_db, mock_db_connection + ): """testbed_evaluations should uppercase the tid.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -99,7 +108,9 @@ class TestTestbedEvaluation: @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.process_report") - async def test_testbed_evaluation_returns_report(self, mock_process_report, mock_get_db, mock_db_connection): + async def test_testbed_evaluation_returns_report( + self, mock_process_report, mock_get_db, mock_db_connection + ): """testbed_evaluation should return evaluation report.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -120,7 +131,9 @@ class TestTestbedTestsetQa: @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") - async def test_testbed_testset_qa_returns_data(self, mock_get_qa, mock_get_db, mock_db_connection): + async def test_testbed_testset_qa_returns_data( + self, mock_get_qa, 
mock_get_db, mock_db_connection + ): """testbed_testset_qa should return Q&A data.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -141,7 +154,9 @@ class TestTestbedDeleteTestset: @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.delete_qa") - async def test_testbed_delete_testset_success(self, mock_delete_qa, mock_get_db, mock_db_connection): + async def test_testbed_delete_testset_success( + self, mock_delete_qa, mock_get_db, mock_db_connection + ): """testbed_delete_testset should delete and return success.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -185,7 +200,9 @@ async def test_testbed_upsert_testsets_success( @pytest.mark.asyncio @patch("server.api.v1.testbed.utils_databases.get_client_database") @patch("server.api.v1.testbed.utils_testbed.jsonl_to_json_content") - async def test_testbed_upsert_testsets_handles_exception(self, mock_jsonl, mock_get_db, mock_db_connection): + async def test_testbed_upsert_testsets_handles_exception( + self, mock_jsonl, mock_get_db, mock_db_connection + ): """testbed_upsert_testsets should raise 500 on exception.""" mock_db = MagicMock() mock_db.connection = mock_db_connection @@ -195,7 +212,9 @@ async def test_testbed_upsert_testsets_handles_exception(self, mock_jsonl, mock_ mock_file = UploadFile(file=BytesIO(b"invalid"), filename="test.jsonl") with pytest.raises(HTTPException) as exc_info: - await testbed.testbed_upsert_testsets(files=[mock_file], name="Test", tid=None, client="test_client") + await testbed.testbed_upsert_testsets( + files=[mock_file], name="Test", tid=None, client="test_client" + ) assert exc_info.value.status_code == 500 @@ -224,7 +243,9 @@ def test_handle_testset_error_value_error(self, tmp_path): def test_handle_testset_error_api_connection_error(self, tmp_path): """_handle_testset_error should raise 424 for API connection error.""" - ex = litellm.APIConnectionError(message="Connection failed", llm_provider="openai", model="gpt-4") + ex = litellm.APIConnectionError( + message="Connection failed", llm_provider="openai", model="gpt-4" + ) with pytest.raises(HTTPException) as exc_info: testbed._handle_testset_error(ex, tmp_path, "test-model") @@ -293,6 +314,303 @@ def test_auth_router_has_routes(self): assert "/evaluate" in routes +class TestProcessFileForTestset: + """Tests for the _process_file_for_testset helper function.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_testbed.load_and_split") + @patch("server.api.v1.testbed.utils_testbed.build_knowledge_base") + async def test_process_file_writes_and_processes( + self, mock_build_kb, mock_load_split, tmp_path + ): + """_process_file_for_testset should write file and build knowledge base.""" + mock_load_split.return_value = ["node1", "node2"] + mock_testset = MagicMock() + + # Make save create an actual file (function reads it after save) + def save_side_effect(path): + with open(path, "w", encoding="utf-8") as f: + f.write('{"question": "generated"}\n') + + mock_testset.save = save_side_effect + mock_build_kb.return_value = mock_testset + + mock_file = MagicMock() + mock_file.read = AsyncMock(return_value=b"file content") + mock_file.filename = "test.pdf" + + full_testsets = tmp_path / "all_testsets.jsonl" + full_testsets.touch() + + await testbed._process_file_for_testset( + file=mock_file, + temp_directory=tmp_path, + full_testsets=full_testsets, + name="TestSet", + questions=5, + ll_model="gpt-4", + embed_model="text-embedding-3", 
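+            # oci_config below is a stand-in: build_knowledge_base is patched, so no real OCI call is made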
+ oci_config=MagicMock(), + ) + + mock_load_split.assert_called_once() + mock_build_kb.assert_called_once() + # Verify file was created (save was called) + assert (tmp_path / "TestSet.jsonl").exists() + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.utils_testbed.load_and_split") + @patch("server.api.v1.testbed.utils_testbed.build_knowledge_base") + async def test_process_file_appends_to_full_testsets( + self, mock_build_kb, mock_load_split, tmp_path + ): + """_process_file_for_testset should append to full_testsets file.""" + mock_load_split.return_value = ["node1"] + mock_testset = MagicMock() + + def save_side_effect(path): + with open(path, "w", encoding="utf-8") as f: + f.write('{"question": "Q1"}\n') + + mock_testset.save = save_side_effect + mock_build_kb.return_value = mock_testset + + mock_file = MagicMock() + mock_file.read = AsyncMock(return_value=b"content") + mock_file.filename = "test.pdf" + + full_testsets = tmp_path / "all_testsets.jsonl" + full_testsets.write_text('{"question": "existing"}\n') + + await testbed._process_file_for_testset( + file=mock_file, + temp_directory=tmp_path, + full_testsets=full_testsets, + name="TestSet", + questions=2, + ll_model="gpt-4", + embed_model="embed", + oci_config=MagicMock(), + ) + + content = full_testsets.read_text() + assert '{"question": "existing"}' in content + assert '{"question": "Q1"}' in content + + +class TestCollectTestbedAnswers: + """Tests for the _collect_testbed_answers helper function.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.chat.chat_post") + async def test_collect_answers_returns_agent_answers(self, mock_chat_post): + """_collect_testbed_answers should return list of AgentAnswer objects.""" + mock_chat_post.return_value = { + "choices": [{"message": {"content": "Test response"}}] + } + + mock_df = MagicMock() + mock_df.itertuples.return_value = [ + MagicMock(question="Question 1"), + MagicMock(question="Question 2"), + ] + mock_testset = MagicMock() + mock_testset.to_pandas.return_value = mock_df + + result = await testbed._collect_testbed_answers(mock_testset, "test_client") + + assert len(result) == 2 + assert result[0].message == "Test response" + assert result[1].message == "Test response" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.chat.chat_post") + async def test_collect_answers_calls_chat_for_each_question(self, mock_chat_post): + """_collect_testbed_answers should call chat endpoint for each question.""" + mock_chat_post.return_value = { + "choices": [{"message": {"content": "Response"}}] + } + + mock_df = MagicMock() + mock_df.itertuples.return_value = [ + MagicMock(question="Q1"), + MagicMock(question="Q2"), + MagicMock(question="Q3"), + ] + mock_testset = MagicMock() + mock_testset.to_pandas.return_value = mock_df + + await testbed._collect_testbed_answers(mock_testset, "client123") + + assert mock_chat_post.call_count == 3 + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.chat.chat_post") + async def test_collect_answers_empty_testset(self, mock_chat_post): + """_collect_testbed_answers should return empty list for empty testset.""" + mock_df = MagicMock() + mock_df.itertuples.return_value = [] + mock_testset = MagicMock() + mock_testset.to_pandas.return_value = mock_df + + result = await testbed._collect_testbed_answers(mock_testset, "client") + + assert result == [] + mock_chat_post.assert_not_called() + + +class TestTestbedEvaluate: + """Tests for the testbed_evaluate endpoint.""" + + @pytest.mark.asyncio + @patch("server.api.v1.testbed.pickle.dumps") + 
@patch("server.api.v1.testbed.utils_settings.get_client") + @patch("server.api.v1.testbed.utils_databases.get_client_database") + @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") + @patch("server.api.v1.testbed.utils_embed.get_temp_directory") + @patch("server.api.v1.testbed.QATestset.load") + @patch("server.api.v1.testbed.utils_oci.get") + @patch("server.api.v1.testbed.utils_models.get_litellm_config") + @patch("server.api.v1.testbed.set_llm_model") + @patch("server.api.v1.testbed.get_prompt_with_override") + @patch("server.api.v1.testbed._collect_testbed_answers") + @patch("server.api.v1.testbed.evaluate") + @patch("server.api.v1.testbed.utils_testbed.insert_evaluation") + @patch("server.api.v1.testbed.utils_testbed.process_report") + @patch("server.api.v1.testbed.shutil.rmtree") + async def test_testbed_evaluate_success( + self, + _mock_rmtree, + mock_process_report, + mock_insert_eval, + mock_evaluate, + mock_collect_answers, + mock_get_prompt, + _mock_set_llm, + mock_get_litellm, + mock_oci_get, + mock_qa_load, + mock_get_temp_dir, + mock_get_testset_qa, + mock_get_db, + mock_get_settings, + mock_pickle_dumps, + mock_db_connection, + tmp_path, + ): + """testbed_evaluate should run evaluation and return report.""" + mock_pickle_dumps.return_value = b"pickled_report" + + mock_settings = MagicMock() + mock_settings.ll_model = MagicMock() + mock_settings.vector_search = MagicMock() + mock_settings.model_dump_json.return_value = "{}" + mock_get_settings.return_value = mock_settings + + mock_db = MagicMock() + mock_db.connection = mock_db_connection + mock_get_db.return_value = mock_db + + mock_get_testset_qa.return_value = MagicMock(qa_data=[{"q": "Q1", "a": "A1"}]) + mock_get_temp_dir.return_value = tmp_path + + mock_loaded_testset = MagicMock() + mock_qa_load.return_value = mock_loaded_testset + + mock_oci_get.return_value = MagicMock() + mock_get_litellm.return_value = {"api_key": "test"} + + mock_prompt_msg = MagicMock() + mock_prompt_msg.content.text = "You are a judge." 
+        mock_get_prompt.return_value = mock_prompt_msg
+
+        mock_collect_answers.return_value = [MagicMock(message="Answer")]
+
+        mock_report = MagicMock()
+        mock_report.correctness = 0.85
+        mock_evaluate.return_value = mock_report
+
+        mock_insert_eval.return_value = "EID123"
+
+        mock_eval_report = MagicMock()
+        mock_process_report.return_value = mock_eval_report
+
+        result = await testbed.testbed_evaluate(
+            tid="TS001",
+            judge="gpt-4",
+            client="test_client",
+        )
+
+        assert result == mock_eval_report
+        assert mock_settings.ll_model.chat_history is False
+        assert mock_settings.vector_search.grade is False
+        mock_evaluate.assert_called_once()
+        mock_insert_eval.assert_called_once()
+        mock_db_connection.commit.assert_called()
+
+    @pytest.mark.asyncio
+    @patch("server.api.v1.testbed.utils_settings.get_client")
+    @patch("server.api.v1.testbed.utils_databases.get_client_database")
+    @patch("server.api.v1.testbed.utils_testbed.get_testset_qa")
+    @patch("server.api.v1.testbed.utils_embed.get_temp_directory")
+    @patch("server.api.v1.testbed.QATestset.load")
+    @patch("server.api.v1.testbed.utils_oci.get")
+    @patch("server.api.v1.testbed.utils_models.get_litellm_config")
+    @patch("server.api.v1.testbed.set_llm_model")
+    @patch("server.api.v1.testbed.get_prompt_with_override")
+    @patch("server.api.v1.testbed._collect_testbed_answers")
+    @patch("server.api.v1.testbed.evaluate")
+    async def test_testbed_evaluate_raises_500_on_correctness_key_error(
+        self,
+        mock_evaluate,
+        mock_collect_answers,
+        mock_get_prompt,
+        _mock_set_llm,
+        mock_get_litellm,
+        mock_oci_get,
+        mock_qa_load,
+        mock_get_temp_dir,
+        mock_get_testset_qa,
+        mock_get_db,
+        mock_get_settings,
+        mock_db_connection,
+        tmp_path,
+    ):
+        """testbed_evaluate should raise 500 when correctness key is missing."""
+        mock_settings = MagicMock()
+        mock_settings.ll_model = MagicMock()
+        mock_settings.vector_search = MagicMock()
+        mock_get_settings.return_value = mock_settings
+
+        mock_db = MagicMock()
+        mock_db.connection = mock_db_connection
+        mock_get_db.return_value = mock_db
+
+        mock_get_testset_qa.return_value = MagicMock(qa_data=[{"q": "Q1"}])
+        mock_get_temp_dir.return_value = tmp_path
+
+        mock_qa_load.return_value = MagicMock()
+        mock_oci_get.return_value = MagicMock()
+        mock_get_litellm.return_value = {}
+
+        mock_prompt_msg = MagicMock()
+        mock_prompt_msg.content.text = "Judge prompt"
+        mock_get_prompt.return_value = mock_prompt_msg
+
+        mock_collect_answers.return_value = []
+        mock_evaluate.side_effect = KeyError("correctness")
+
+        with pytest.raises(HTTPException) as exc_info:
+            await testbed.testbed_evaluate(
+                tid="TS001",
+                judge="gpt-4",
+                client="test_client",
+            )
+
+        assert exc_info.value.status_code == 500
+        assert "correctness" in str(exc_info.value.detail)
+
+
 class TestLoggerConfiguration:
     """Tests for logger configuration."""


From 60ebbfaf0bde3b88cc85aec9f4e4e90093a3e071 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Fri, 28 Nov 2025 19:07:13 +0000
Subject: [PATCH 11/20] Move OpenTofu tests

---
 {tests => test}/opentofu/OMRMetaSchema.yaml     | 0
 {tests => test}/opentofu/validate_omr_schema.py | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename {tests => test}/opentofu/OMRMetaSchema.yaml (100%)
 rename {tests => test}/opentofu/validate_omr_schema.py (100%)

diff --git a/tests/opentofu/OMRMetaSchema.yaml b/test/opentofu/OMRMetaSchema.yaml
similarity index 100%
rename from tests/opentofu/OMRMetaSchema.yaml
rename to test/opentofu/OMRMetaSchema.yaml
diff --git a/tests/opentofu/validate_omr_schema.py b/test/opentofu/validate_omr_schema.py
similarity
index 100% rename from tests/opentofu/validate_omr_schema.py rename to test/opentofu/validate_omr_schema.py From 7230bab3ae3443df0a942c74de50d8143b1eca41 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 19:33:42 +0000 Subject: [PATCH 12/20] Initial NL2SQL --- src/server/agents/chatbot.py | 135 +++++++++++++++++++++++++++++++++-- src/server/api/utils/chat.py | 19 ++++- 2 files changed, 147 insertions(+), 7 deletions(-) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index 16ff3ff5..b6d49ee1 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -74,11 +74,16 @@ def clean_messages(state: OptimizerState, config: RunnableConfig) -> list: return state_messages -def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "stream_completion"]: - """Conditional edge to determine if using Vector Search or not""" - enabled = "Vector Search" in config.get("metadata", {}).get("tools_enabled", []) - if enabled: - logger.info("Invoking Chatbot with Vector Search: %s", enabled) +def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "nl2sql", "stream_completion"]: + """Conditional edge to determine which tool path to use""" + tools_enabled = config.get("metadata", {}).get("tools_enabled", []) + + if "NL2SQL" in tools_enabled: + logger.info("Invoking Chatbot with NL2SQL") + return "nl2sql" + + if "Vector Search" in tools_enabled: + logger.info("Invoking Chatbot with Vector Search") return "vs_retrieve" return "stream_completion" @@ -237,6 +242,122 @@ async def vs_retrieve(state: OptimizerState, config: RunnableConfig) -> Optimize logger.info("Found Documents: %i", len(documents_dict)) return {"context_input": retrieve_question, "documents": documents_dict} +async def nl2sql(state: OptimizerState, config: RunnableConfig) -> OptimizerState: + """Execute NL2SQL tool calling flow with streaming response. + + This node handles the complete NL2SQL agentic loop: + 1. Calls the LLM with NL2SQL tools bound + 2. If the LLM requests tool calls, executes them via MCP + 3. Calls LLM again with tool results to generate final response + 4. 
Streams the final response to the client + """ + writer = get_stream_writer() + tools = config["metadata"].get("tools", []) + if not tools: + logger.warning("NL2SQL enabled but no tools provided") + return {} + + messages = copy.deepcopy(state["cleaned_messages"]) + sys_prompt_msg = default_prompts.get_prompt_with_override("optimizer_basic-default") + messages.insert(0, SystemMessage(content=sys_prompt_msg.content.text)) + + ll_raw = config["configurable"]["ll_config"] + mcp_client = config["metadata"].get("mcp_client") + streaming_enabled = config["metadata"].get("streaming", True) + + try: + # Agentic loop - continue until LLM produces final response (no tool calls) + max_iterations = 10 # Safety limit + iteration = 0 + + while iteration < max_iterations: + iteration += 1 + logger.info("NL2SQL: Iteration %d", iteration) + + # Call LLM with tools + response = await acompletion( + messages=convert_to_openai_messages(messages), + tools=tools, + stream=False, + **ll_raw, + ) + + choice = response.choices[0] + tool_calls = getattr(choice.message, "tool_calls", None) + + # If no tool calls, we have the final response + if not tool_calls: + content = choice.message.content or "" + logger.info("NL2SQL: Final response received (no tool calls)") + + # Stream the response if streaming is enabled + if streaming_enabled and content: + for char in content: + writer({"stream": char}) + + # Build completion response + response.object = "chat.completion" + writer({"completion": response.model_dump()}) + + return {"messages": [AIMessage(content=content)]} + + # Execute tool calls + logger.info("NL2SQL: Executing %d tool calls", len(tool_calls)) + + # Add assistant message with tool calls to conversation + ai_message = AIMessage( + content=choice.message.content or "", + tool_calls=[ + { + "name": tc.function.name, + "args": json.loads(tc.function.arguments) if tc.function.arguments else {}, + "id": tc.id, + "type": "tool_call", + } + for tc in tool_calls + ], + ) + messages.append(ai_message) + + # Execute each tool call and add results + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + tool_id = tool_call.id + + logger.info("NL2SQL: Calling tool %s with args %s", tool_name, tool_args) + + try: + if mcp_client: + mcp_tools = await mcp_client.get_tools() + tool = next((t for t in mcp_tools if t.name == tool_name), None) + if tool: + result = await tool.ainvoke(tool_args) + if isinstance(result, dict): + result = json.dumps(result, indent=2, cls=DecimalEncoder) + elif not isinstance(result, str): + result = str(result) + else: + result = f"Unknown tool: {tool_name}" + else: + result = f"MCP client not available for tool: {tool_name}" + except Exception as ex: + logger.error("NL2SQL: Tool %s failed: %s", tool_name, ex) + result = f"Error executing {tool_name}: {str(ex)}" + + messages.append(ToolMessage(content=result, tool_call_id=tool_id, name=tool_name)) + + # If we hit max iterations, return what we have + logger.warning("NL2SQL: Max iterations reached") + return {"messages": [AIMessage(content="I'm sorry, I wasn't able to complete the request.")]} + + except APIConnectionError as ex: + logger.error("NL2SQL: API connection error: %s", ex) + return {"messages": [AIMessage(content="I'm not able to contact the model API.")]} + except Exception as ex: + logger.error("NL2SQL: Unexpected error: %s", ex) + return {"messages": [AIMessage(content=f"An error occurred: {str(ex)}")]} + def 
_build_system_prompt(state: OptimizerState, config: RunnableConfig) -> SystemMessage: """Build the system prompt based on vector search configuration.""" @@ -318,15 +439,17 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig) -> Op workflow.add_node("rephrase", rephrase) workflow.add_node("vs_retrieve", vs_retrieve) workflow.add_node("vs_grade", vs_grade) +workflow.add_node("nl2sql", nl2sql) workflow.add_node("stream_completion", stream_completion) # Start the chatbot with clean messages workflow.add_edge(START, "initialise") -# Branch to either "vs_retrieve", or "stream_completion" +# Branch to "vs_retrieve", "nl2sql", or "stream_completion" workflow.add_conditional_edges("initialise", use_tool) workflow.add_edge("vs_retrieve", "vs_grade") workflow.add_edge("vs_grade", "stream_completion") +workflow.add_edge("nl2sql", END) # nl2sql handles its own completion and streaming # End the workflow workflow.add_edge("stream_completion", END) diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 6976a101..7061eccb 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -7,11 +7,15 @@ from typing import Literal, AsyncGenerator from litellm import completion + from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig +from langchain_core.utils.function_calling import convert_to_openai_function -import server.api.utils.settings as utils_settings +from langchain_mcp_adapters.client import MultiServerMCPClient +import server.api.utils.settings as utils_settings +import server.api.utils.mcp as utils_mcp import server.api.utils.oci as utils_oci import server.api.utils.models as utils_models import server.api.utils.databases as utils_databases @@ -80,6 +84,19 @@ async def completion_generator( client_settings.vector_search.model_dump(), oci_config ) + if "NL2SQL" in client_settings.tools_enabled: + mcp_client = MultiServerMCPClient( + {"optimizer": utils_mcp.get_client(client="langgraph")["mcpServers"]["optimizer"]} + ) + tools = await mcp_client.get_tools() + nl2sql_tools = [tool for tool in tools if tool.name.startswith("sqlcl_")] + # Convert LangChain tools to OpenAI Functions for binding to LiteLLM model + kwargs["config"]["metadata"]["tools"] = [ + {"type": "function", "function": convert_to_openai_function(t)} for t in nl2sql_tools + ] + # Pass MCP client for tool execution in nl2sql node + kwargs["config"]["metadata"]["mcp_client"] = mcp_client + logger.debug("Completion Kwargs: %s", kwargs) final_response = None async for output in chatbot_graph.astream(**kwargs): From b42e0fa9892589659d3d4059d5507f779d78ce57 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 19:55:20 +0000 Subject: [PATCH 13/20] comment the key usage --- src/client/content/config/tabs/mcp.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py index 484de2b9..b78b79a8 100644 --- a/src/client/content/config/tabs/mcp.py +++ b/src/client/content/config/tabs/mcp.py @@ -123,17 +123,19 @@ def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None: col1.markdown("Name", unsafe_allow_html=True) col2.markdown("​") for mcp_name in configs: - col1.text_input( + # The key prefix is to give each widget a unique key in the loop; the key itself is never used + key_prefix = f"{mcp_server}_{mcp_type}_{mcp_name}" + col1.text( "Name", value=mcp_name, label_visibility="collapsed", disabled=True, - 
key=f"{mcp_server}_{mcp_type}_{mcp_name}_input", + key=f"{key_prefix}_name", ) col2.button( "Details", on_click=mcp_details, - key=f"{mcp_server}_{mcp_type}_{mcp_name}_details", + key=f"{key_prefix}_details", kwargs={"mcp_server": mcp_server, "mcp_type": mcp_type, "mcp_name": mcp_name}, ) From 48d3b2671a6a9c981f4a6e9b1397ad59785eba67 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 20:46:55 +0000 Subject: [PATCH 14/20] Fix to rephrase --- src/client/content/chatbot.py | 81 +++++++++++++++++++++++++++++------ src/server/agents/chatbot.py | 13 ++++-- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py index f9673581..cd49ee9c 100644 --- a/src/client/content/chatbot.py +++ b/src/client/content/chatbot.py @@ -26,15 +26,15 @@ ############################################################################# # Functions ############################################################################# -def show_vector_search_refs(context): +def show_vector_search_refs(context, vs_metadata=None): """When Vector Search Content Found, show the references""" st.markdown("**References:**") ref_src = set() ref_cols = st.columns([3, 3, 3]) # Create a button in each column - for i, (ref_col, chunk) in enumerate(zip(ref_cols, context[0])): + for i, (ref_col, chunk) in enumerate(zip(ref_cols, context["documents"])): with ref_col.popover(f"Reference: {i + 1}"): - chunk = context[0][i] + chunk = context["documents"][i] logger.debug("Chunk Content: %s", chunk) st.subheader("Reference Text", divider="red") st.markdown(chunk["page_content"]) @@ -46,9 +46,32 @@ def show_vector_search_refs(context): except KeyError: logger.error("Chunk Metadata NOT FOUND!!") - for link in ref_src: - st.markdown("- " + link) - st.markdown(f"**Notes:** Vector Search Query - {context[1]}") + # Display Vector Search details in expander + if vs_metadata or ref_src: + with st.expander("Vector Search Details", expanded=False): + if ref_src: + st.markdown("**Source Documents:**") + for link in ref_src: + st.markdown(f"- {link}") + + if vs_metadata and vs_metadata.get("searched_tables"): + st.markdown("**Tables Searched:**") + for table in vs_metadata["searched_tables"]: + st.markdown(f"- {table}") + + if vs_metadata and vs_metadata.get("context_input"): + st.markdown(f"**Search Query:** {vs_metadata.get('context_input')}") + elif context.get("context_input"): + st.markdown(f"**Search Query:** {context.get('context_input')}") + + +def show_token_usage(token_usage): + """Display token usage for AI responses using caption""" + if token_usage: + prompt_tokens = token_usage.get("prompt_tokens", 0) + completion_tokens = token_usage.get("completion_tokens", 0) + total_tokens = token_usage.get("total_tokens", 0) + st.caption(f"Token usage: {prompt_tokens} prompt + {completion_tokens} completion = {total_tokens} total") def setup_sidebar(): @@ -80,7 +103,7 @@ def create_client(): def display_chat_history(history): - """Display chat history messages""" + """Display chat history messages with metadata""" st.chat_message("ai").write("Hello, how can I help you?") vector_search_refs = [] @@ -88,14 +111,25 @@ def display_chat_history(history): if not message["content"]: continue - if message["role"] == "tool" and message["name"] == "oraclevs_tool": + if message["role"] == "tool" and message["name"] == "optimizer_vs-retriever": vector_search_refs = json.loads(message["content"]) elif message["role"] in ("ai", "assistant"): with st.chat_message("ai"): 
st.markdown(message["content"]) + + # Extract metadata from response_metadata + response_metadata = message.get("response_metadata", {}) + vs_metadata = response_metadata.get("vs_metadata", {}) + token_usage = response_metadata.get("token_usage", {}) + + # Show token usage immediately after message + if token_usage: + show_token_usage(token_usage) + + # Show vector search references if available if vector_search_refs: - show_vector_search_refs(vector_search_refs) + show_vector_search_refs(vector_search_refs, vs_metadata) vector_search_refs = [] elif message["role"] in ("human", "user"): @@ -131,9 +165,32 @@ async def handle_chat_input(user_client): try: message_placeholder = st.chat_message("ai").empty() full_answer = "" - async for chunk in user_client.stream(message=human_request.text, image_b64=file_b64): - full_answer += chunk - message_placeholder.markdown(full_answer) + + # Animated thinking indicator + async def animate_thinking(): + """Animate the thinking indicator with increasing dots""" + dots = 0 + while True: + message_placeholder.markdown(f"🤔 Thinking{'.' * (dots % 4)}") + dots += 1 + await asyncio.sleep(0.5) # Update every 500ms + + # Start the thinking animation + thinking_task = asyncio.create_task(animate_thinking()) + + try: + async for chunk in user_client.stream(message=human_request.text, image_b64=file_b64): + # Cancel thinking animation on first chunk + if thinking_task and not thinking_task.done(): + thinking_task.cancel() + thinking_task = None + full_answer += chunk + message_placeholder.markdown(full_answer) + finally: + # Ensure thinking task is cancelled + if thinking_task and not thinking_task.done(): + thinking_task.cancel() + st.rerun() except (ConnectionError, TimeoutError, api_call.ApiError) as ex: logger.exception("Error during chat streaming: %s", ex) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index b6d49ee1..d7b59fae 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -97,12 +97,15 @@ def rephrase(state: OptimizerState, config: RunnableConfig) -> str: rephrase_prompt_msg = default_prompts.get_prompt_with_override("optimizer_vs-rephrase") rephrase_template_text = rephrase_prompt_msg.content.text + context_prompt_msg = default_prompts.get_prompt_with_override("optimizer_context-default") + context_prompt_text = context_prompt_msg.content.text + rephrase_template = PromptTemplate( template=rephrase_template_text, - input_variables=["ctx_prompt", "history", "question"], + input_variables=["prompt", "history", "question"], ) formatted_prompt = rephrase_template.format( - prompt=rephrase_template_text, history=state["messages"], question=retrieve_question + prompt=context_prompt_text, history=state["messages"], question=retrieve_question ) ll_raw = config["configurable"]["ll_config"] try: @@ -177,8 +180,10 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt state["messages"].append( ToolMessage( - content=json.dumps([state["documents"], state["context_input"]], cls=DecimalEncoder), - name="oraclevs_tool", + content=json.dumps( + {"documents": state["documents"], "context_input": state["context_input"]}, cls=DecimalEncoder + ), + name="optimizer_vs-retriever", tool_call_id="tool_placeholder", ) ) From 9f89e961ac70e4439b471da313028f716fc65a10 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 20:51:52 +0000 Subject: [PATCH 15/20] wrong widget --- src/client/content/config/tabs/mcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py index b78b79a8..32760515 100644 --- a/src/client/content/config/tabs/mcp.py +++ b/src/client/content/config/tabs/mcp.py @@ -125,7 +125,7 @@ def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None: for mcp_name in configs: # The key prefix is to give each widget a unique key in the loop; the key itself is never used key_prefix = f"{mcp_server}_{mcp_type}_{mcp_name}" - col1.text( + col1.text_input( "Name", value=mcp_name, label_visibility="collapsed", From c4707b5a4a4661448dbbc859a14dc153fbc60794 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 20:59:44 +0000 Subject: [PATCH 16/20] moved test refresh to new branch --- test/__init__.py | 1 - test/conftest.py | 27 - test/db_fixtures.py | 209 ----- test/integration/__init__.py | 6 - test/integration/server/__init__.py | 6 - test/integration/server/api/__init__.py | 6 - test/integration/server/api/conftest.py | 162 ---- test/integration/server/api/utils/__init__.py | 6 - test/integration/server/api/v1/__init__.py | 6 - .../server/api/v1/test_databases.py | 153 ---- test/integration/server/api/v1/test_models.py | 271 ------ test/integration/server/api/v1/test_oci.py | 224 ----- test/integration/server/api/v1/test_probes.py | 74 -- .../server/api/v1/test_settings.py | 307 ------- test/integration/server/bootstrap/__init__.py | 1 - test/integration/server/bootstrap/conftest.py | 147 ---- .../bootstrap/test_bootstrap_configfile.py | 245 ------ .../bootstrap/test_bootstrap_databases.py | 197 ----- .../server/bootstrap/test_bootstrap_models.py | 263 ------ .../server/bootstrap/test_bootstrap_oci.py | 246 ------ .../bootstrap/test_bootstrap_settings.py | 170 ---- test/shared_fixtures.py | 330 ------- test/unit/__init__.py | 1 - test/unit/server/__init__.py | 1 - test/unit/server/api/__init__.py | 1 - test/unit/server/api/conftest.py | 194 ----- test/unit/server/api/utils/__init__.py | 1 - test/unit/server/api/utils/test_utils_chat.py | 312 ------- .../server/api/utils/test_utils_databases.py | 657 -------------- .../unit/server/api/utils/test_utils_embed.py | 805 ------------------ test/unit/server/api/utils/test_utils_mcp.py | 192 ----- .../server/api/utils/test_utils_models.py | 433 ---------- test/unit/server/api/utils/test_utils_oci.py | 595 ------------- .../server/api/utils/test_utils_settings.py | 352 -------- .../server/api/utils/test_utils_testbed.py | 324 ------- .../api/utils/test_utils_testbed_metrics.py | 345 -------- .../server/api/utils/test_utils_webscrape.py | 419 --------- test/unit/server/api/v1/__init__.py | 1 - test/unit/server/api/v1/test_v1_chat.py | 258 ------ test/unit/server/api/v1/test_v1_databases.py | 290 ------- test/unit/server/api/v1/test_v1_embed.py | 553 ------------ test/unit/server/api/v1/test_v1_mcp.py | 169 ---- .../unit/server/api/v1/test_v1_mcp_prompts.py | 229 ----- test/unit/server/api/v1/test_v1_models.py | 254 ------ test/unit/server/api/v1/test_v1_oci.py | 362 -------- test/unit/server/api/v1/test_v1_probes.py | 129 --- test/unit/server/api/v1/test_v1_settings.py | 326 ------- test/unit/server/api/v1/test_v1_testbed.py | 623 -------------- test/unit/server/bootstrap/__init__.py | 1 - test/unit/server/bootstrap/conftest.py | 57 -- .../bootstrap/test_bootstrap_bootstrap.py | 183 ---- .../bootstrap/test_bootstrap_configfile.py | 229 ----- .../bootstrap/test_bootstrap_databases.py | 219 ----- .../server/bootstrap/test_bootstrap_models.py | 413 --------- .../server/bootstrap/test_bootstrap_oci.py | 329 
------- .../bootstrap/test_bootstrap_settings.py | 143 ---- .../content/tools/tabs/test_split_embed.py | 10 +- .../integration/utils/test_st_common.py | 327 ++++++- .../client/unit/content/test_chatbot_unit.py | 8 +- .../tools/tabs/test_split_embed_unit.py | 19 +- .../client/unit/utils/test_st_common_unit.py | 278 +++++- {test => tests}/opentofu/OMRMetaSchema.yaml | 0 .../opentofu/validate_omr_schema.py | 0 .../integration/test_endpoints_settings.py | 7 +- .../server/unit/api/utils/test_utils_chat.py | 52 +- .../api/utils/test_utils_databases_crud.py | 124 ++- .../utils/test_utils_databases_functions.py | 236 ++++- .../server/unit/api/utils/test_utils_embed.py | 53 +- .../unit/api/utils/test_utils_models.py | 254 ++++-- tests/server/unit/api/utils/test_utils_oci.py | 268 ++++-- .../unit/api/utils/test_utils_oci_refresh.py | 25 +- .../unit/api/utils/test_utils_settings.py | 225 ++++- .../unit/api/utils/test_utils_testbed.py | 17 +- tests/server/unit/bootstrap/test_bootstrap.py | 60 +- 74 files changed, 1670 insertions(+), 12750 deletions(-) delete mode 100644 test/__init__.py delete mode 100644 test/conftest.py delete mode 100644 test/db_fixtures.py delete mode 100644 test/integration/__init__.py delete mode 100644 test/integration/server/__init__.py delete mode 100644 test/integration/server/api/__init__.py delete mode 100644 test/integration/server/api/conftest.py delete mode 100644 test/integration/server/api/utils/__init__.py delete mode 100644 test/integration/server/api/v1/__init__.py delete mode 100644 test/integration/server/api/v1/test_databases.py delete mode 100644 test/integration/server/api/v1/test_models.py delete mode 100644 test/integration/server/api/v1/test_oci.py delete mode 100644 test/integration/server/api/v1/test_probes.py delete mode 100644 test/integration/server/api/v1/test_settings.py delete mode 100644 test/integration/server/bootstrap/__init__.py delete mode 100644 test/integration/server/bootstrap/conftest.py delete mode 100644 test/integration/server/bootstrap/test_bootstrap_configfile.py delete mode 100644 test/integration/server/bootstrap/test_bootstrap_databases.py delete mode 100644 test/integration/server/bootstrap/test_bootstrap_models.py delete mode 100644 test/integration/server/bootstrap/test_bootstrap_oci.py delete mode 100644 test/integration/server/bootstrap/test_bootstrap_settings.py delete mode 100644 test/shared_fixtures.py delete mode 100644 test/unit/__init__.py delete mode 100644 test/unit/server/__init__.py delete mode 100644 test/unit/server/api/__init__.py delete mode 100644 test/unit/server/api/conftest.py delete mode 100644 test/unit/server/api/utils/__init__.py delete mode 100644 test/unit/server/api/utils/test_utils_chat.py delete mode 100644 test/unit/server/api/utils/test_utils_databases.py delete mode 100644 test/unit/server/api/utils/test_utils_embed.py delete mode 100644 test/unit/server/api/utils/test_utils_mcp.py delete mode 100644 test/unit/server/api/utils/test_utils_models.py delete mode 100644 test/unit/server/api/utils/test_utils_oci.py delete mode 100644 test/unit/server/api/utils/test_utils_settings.py delete mode 100644 test/unit/server/api/utils/test_utils_testbed.py delete mode 100644 test/unit/server/api/utils/test_utils_testbed_metrics.py delete mode 100644 test/unit/server/api/utils/test_utils_webscrape.py delete mode 100644 test/unit/server/api/v1/__init__.py delete mode 100644 test/unit/server/api/v1/test_v1_chat.py delete mode 100644 test/unit/server/api/v1/test_v1_databases.py delete mode 100644 
test/unit/server/api/v1/test_v1_embed.py delete mode 100644 test/unit/server/api/v1/test_v1_mcp.py delete mode 100644 test/unit/server/api/v1/test_v1_mcp_prompts.py delete mode 100644 test/unit/server/api/v1/test_v1_models.py delete mode 100644 test/unit/server/api/v1/test_v1_oci.py delete mode 100644 test/unit/server/api/v1/test_v1_probes.py delete mode 100644 test/unit/server/api/v1/test_v1_settings.py delete mode 100644 test/unit/server/api/v1/test_v1_testbed.py delete mode 100644 test/unit/server/bootstrap/__init__.py delete mode 100644 test/unit/server/bootstrap/conftest.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_bootstrap.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_configfile.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_databases.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_models.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_oci.py delete mode 100644 test/unit/server/bootstrap/test_bootstrap_settings.py rename {test => tests}/opentofu/OMRMetaSchema.yaml (100%) rename {test => tests}/opentofu/validate_omr_schema.py (100%) diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index 66173aec..00000000 --- a/test/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Test package diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index c25db9f6..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Pytest fixtures for unit tests with real Oracle database. - -Re-exports shared database fixtures from test.db_fixtures. -""" - -# Re-export shared fixtures for pytest discovery -from test.db_fixtures import ( - TEST_DB_CONFIG, - db_container, - db_connection, - db_transaction, -) - -# Expose TEST_CONFIG alias for backwards compatibility -TEST_CONFIG = TEST_DB_CONFIG - -__all__ = [ - "TEST_CONFIG", - "TEST_DB_CONFIG", - "db_container", - "db_connection", - "db_transaction", -] diff --git a/test/db_fixtures.py b/test/db_fixtures.py deleted file mode 100644 index 51609173..00000000 --- a/test/db_fixtures.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Shared database fixtures and utilities for tests. - -This module provides common database container management functions -used by both unit and integration tests. -""" - -# pylint: disable=redefined-outer-name -# Pytest fixtures use parameter injection where fixture names match parameters - -import time -import shutil -from pathlib import Path -from typing import Generator, Optional -from contextlib import contextmanager - -import pytest -import oracledb -import docker -from docker.errors import DockerException -from docker.models.containers import Container - - -# Test database configuration - shared across all tests -TEST_DB_CONFIG = { - "db_username": "PYTEST", - "db_password": "OrA_41_3xPl0d3r", - "db_dsn": "//localhost:1525/FREEPDB1", -} - - -def wait_for_container_ready( - container: Container, - ready_output: str, - since: Optional[int] = None, - timeout: int = 120, -) -> None: - """Wait for container to be ready by checking its logs with exponential backoff. 
- - Args: - container: Docker container to monitor - ready_output: String to look for in logs indicating readiness - since: Unix timestamp to filter logs from (optional) - timeout: Maximum seconds to wait (default 120) - - Raises: - TimeoutError: If container doesn't become ready within timeout - DockerException: If there's an error getting container logs - """ - start_time = time.time() - retry_interval = 2 - - while time.time() - start_time < timeout: - try: - logs = container.logs(tail=100, since=since).decode("utf-8") - if ready_output in logs: - return - except DockerException as e: - container.remove(force=True) - raise DockerException(f"Failed to get container logs: {str(e)}") from e - - time.sleep(retry_interval) - retry_interval = min(retry_interval * 2, 10) - - container.remove(force=True) - raise TimeoutError("Container did not become ready within timeout") - - -@contextmanager -def temp_sql_setup(temp_dir_path: str = "test/db_startup_temp"): - """Context manager for temporary SQL setup files. - - Creates a temporary directory with SQL initialization scripts - for the Oracle container. - - Args: - temp_dir_path: Path for temporary directory - - Yields: - Path object to the temporary directory - """ - temp_dir = Path(temp_dir_path) - try: - temp_dir.mkdir(exist_ok=True) - sql_content = f""" - alter system set vector_memory_size=512M scope=spfile; - - alter session set container=FREEPDB1; - CREATE TABLESPACE IF NOT EXISTS USERS DATAFILE '/opt/oracle/oradata/FREE/FREEPDB1/users_01.dbf' SIZE 100M; - CREATE USER IF NOT EXISTS "{TEST_DB_CONFIG["db_username"]}" IDENTIFIED BY {TEST_DB_CONFIG["db_password"]} - DEFAULT TABLESPACE "USERS" - TEMPORARY TABLESPACE "TEMP"; - GRANT "DB_DEVELOPER_ROLE" TO "{TEST_DB_CONFIG["db_username"]}"; - ALTER USER "{TEST_DB_CONFIG["db_username"]}" DEFAULT ROLE ALL; - ALTER USER "{TEST_DB_CONFIG["db_username"]}" QUOTA UNLIMITED ON USERS; - - EXIT; - """ - - temp_sql_file = temp_dir / "01_db_user.sql" - temp_sql_file.write_text(sql_content, encoding="UTF-8") - yield temp_dir - finally: - if temp_dir.exists(): - shutil.rmtree(temp_dir) - - -def create_db_container(temp_dir_name: str = "test/db_startup_temp") -> Generator[Container, None, None]: - """Create and manage an Oracle database container for testing. - - This generator function handles the full lifecycle of a Docker-based - Oracle database container for testing purposes. 
- - Args: - temp_dir_name: Path for temporary SQL setup files - - Yields: - Docker Container object for the running database - - Raises: - DockerException: If Docker operations fail - """ - db_client = docker.from_env() - container = None - - try: - with temp_sql_setup(temp_dir_name) as temp_dir: - container = db_client.containers.run( - "container-registry.oracle.com/database/free:latest-lite", - environment={ - "ORACLE_PWD": TEST_DB_CONFIG["db_password"], - "ORACLE_PDB": TEST_DB_CONFIG["db_dsn"].rsplit("/", maxsplit=1)[-1], - }, - ports={"1521/tcp": int(TEST_DB_CONFIG["db_dsn"].split(":")[1].split("/")[0])}, - volumes={str(temp_dir.absolute()): {"bind": "/opt/oracle/scripts/startup", "mode": "ro"}}, - detach=True, - ) - - # Wait for database to be ready - wait_for_container_ready(container, "DATABASE IS READY TO USE!") - - # Restart container to apply vector_memory_size - container.restart() - restart_time = int(time.time()) - wait_for_container_ready(container, "DATABASE IS READY TO USE!", since=restart_time) - - yield container - - except DockerException as e: - if container: - container.remove(force=True) - raise DockerException(f"Docker operation failed: {str(e)}") from e - - finally: - if container: - try: - container.stop(timeout=30) - container.remove() - except DockerException as e: - print(f"Warning: Failed to cleanup database container: {str(e)}") - - -@pytest.fixture(scope="session") -def db_container() -> Generator[Container, None, None]: - """Pytest fixture for Oracle database container. - - Session-scoped fixture that creates and manages an Oracle database - container for the duration of the test session. - """ - yield from create_db_container() - - -@pytest.fixture(scope="session") -def db_connection(db_container) -> Generator[oracledb.Connection, None, None]: - """Session-scoped real Oracle database connection. - - Depends on db_container to ensure database is running. - """ - _ = db_container # Ensure container is running - conn = oracledb.connect( - user=TEST_DB_CONFIG["db_username"], - password=TEST_DB_CONFIG["db_password"], - dsn=TEST_DB_CONFIG["db_dsn"], - ) - yield conn - conn.close() - - -@pytest.fixture -def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: - """Transaction isolation for each test using savepoints. - - Creates a savepoint before each test and rolls back after, - ensuring tests don't affect each other's database state. - - Note: DDL operations (CREATE TABLE, etc.) cause implicit commits - in Oracle, which will invalidate the savepoint. Tests with DDL - should use mocks or handle cleanup manually. - """ - cursor = db_connection.cursor() - cursor.execute("SAVEPOINT test_savepoint") - - yield db_connection - - cursor.execute("ROLLBACK TO SAVEPOINT test_savepoint") - cursor.close() diff --git a/test/integration/__init__.py b/test/integration/__init__.py deleted file mode 100644 index 2577126f..00000000 --- a/test/integration/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests package. -""" diff --git a/test/integration/server/__init__.py b/test/integration/server/__init__.py deleted file mode 100644 index a242d937..00000000 --- a/test/integration/server/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Server integration tests package. -""" diff --git a/test/integration/server/api/__init__.py b/test/integration/server/api/__init__.py deleted file mode 100644 index c4b92db3..00000000 --- a/test/integration/server/api/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Server API integration tests package. -""" diff --git a/test/integration/server/api/conftest.py b/test/integration/server/api/conftest.py deleted file mode 100644 index 6b25297f..00000000 --- a/test/integration/server/api/conftest.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Pytest fixtures for server API integration tests. - -Integration tests use a real FastAPI TestClient with the actual application, -testing the full request/response cycle through the API layer. - -Note: db_container fixture is inherited from test/conftest.py - do not import here. -""" - -# pylint: disable=redefined-outer-name unused-import -# Pytest fixtures use parameter injection where fixture names match parameters - -import os -import asyncio -from typing import Generator - -# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) -from test.db_fixtures import TEST_DB_CONFIG -from test.shared_fixtures import ( - make_database, - make_model, - DEFAULT_LL_MODEL_CONFIG, -) - -import pytest -from fastapi.testclient import TestClient - -from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS - - -# Clear environment variables that could interfere with tests -# This must happen before importing application modules -API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] -DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] -MODEL_VARS = ["ON_PREM_OLLAMA_URL", "ON_PREM_HF_URL", "OPENAI_API_KEY", "PPLX_API_KEY", "COHERE_API_KEY"] -for env_var in [*API_VARS, *DB_VARS, *MODEL_VARS, *[var for var in os.environ if var.startswith("OCI_")]]: - os.environ.pop(env_var, None) - -# Test configuration - extends shared DB config with integration-specific settings -TEST_CONFIG = { - "client": "integration_test", - "auth_token": "integration-test-token", - **TEST_DB_CONFIG, -} - -# Set environment variables for test server -os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" # Use empty config -os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" # Prevent OCI config pickup -os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] - - -################################################# -# Authentication Headers -################################################# -@pytest.fixture -def auth_headers(): - """Return common header configurations for testing.""" - return { - "no_auth": {}, - "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CONFIG["client"]}, - "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": TEST_CONFIG["client"]}, - } - - -################################################# -# FastAPI Test Client -################################################# -@pytest.fixture(scope="session") -def app(): - """Create the FastAPI application for testing. 
-
-    This fixture creates the actual FastAPI app using the same factory
-    function as the production server (launch_server.create_app).
-
-    Import is done inside the fixture to ensure environment variables
-    are set before any application modules are loaded.
-    """
-    # pylint: disable=import-outside-toplevel
-    from launch_server import create_app
-
-    return asyncio.run(create_app())
-
-
-@pytest.fixture(scope="session")
-def client(app) -> Generator[TestClient, None, None]:
-    """Create a TestClient for the FastAPI app.
-
-    The TestClient allows making HTTP requests to the app without
-    starting a real server, enabling fast integration testing.
-    """
-    with TestClient(app) as test_client:
-        yield test_client
-
-
-#################################################
-# Test Data Helpers
-#################################################
-@pytest.fixture
-def test_db_payload():
-    """Get standard test database payload for integration tests."""
-    return {
-        "user": TEST_CONFIG["db_username"],
-        "password": TEST_CONFIG["db_password"],
-        "dsn": TEST_CONFIG["db_dsn"],
-    }
-
-
-@pytest.fixture
-def sample_model_payload():
-    """Sample model configuration for testing."""
-    return {
-        "id": "test-model",
-        "type": "ll",
-        "provider": "openai",
-        "enabled": True,
-    }
-
-
-@pytest.fixture
-def sample_settings_payload():
-    """Sample settings configuration for testing."""
-    return {
-        "client": TEST_CONFIG["client"],
-        "ll_model": DEFAULT_LL_MODEL_CONFIG.copy(),
-    }
-
-
-#################################################
-# State Management Helpers
-#################################################
-@pytest.fixture
-def db_objects_manager():
-    """Fixture to manage DATABASE_OBJECTS save/restore operations.
-
-    This fixture saves the current state of DATABASE_OBJECTS before each test
-    and restores it afterward, ensuring tests don't affect each other.
-    """
-    original_db_objects = DATABASE_OBJECTS.copy()
-    yield DATABASE_OBJECTS
-    DATABASE_OBJECTS.clear()
-    DATABASE_OBJECTS.extend(original_db_objects)
-
-
-@pytest.fixture
-def model_objects_manager():
-    """Fixture to manage MODEL_OBJECTS save/restore operations."""
-    original_model_objects = MODEL_OBJECTS.copy()
-    yield MODEL_OBJECTS
-    MODEL_OBJECTS.clear()
-    MODEL_OBJECTS.extend(original_model_objects)
-
-
-@pytest.fixture
-def settings_objects_manager():
-    """Fixture to manage SETTINGS_OBJECTS save/restore operations."""
-    original_settings_objects = SETTINGS_OBJECTS.copy()
-    yield SETTINGS_OBJECTS
-    SETTINGS_OBJECTS.clear()
-    SETTINGS_OBJECTS.extend(original_settings_objects)
diff --git a/test/integration/server/api/utils/__init__.py b/test/integration/server/api/utils/__init__.py
deleted file mode 100644
index 37340b95..00000000
--- a/test/integration/server/api/utils/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Server API utils integration tests package.
-"""
diff --git a/test/integration/server/api/v1/__init__.py b/test/integration/server/api/v1/__init__.py
deleted file mode 100644
index d55308b1..00000000
--- a/test/integration/server/api/v1/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Server API v1 integration tests package.
-""" diff --git a/test/integration/server/api/v1/test_databases.py b/test/integration/server/api/v1/test_databases.py deleted file mode 100644 index 847ef735..00000000 --- a/test/integration/server/api/v1/test_databases.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/api/v1/databases.py - -Tests the database configuration endpoints through the full API stack. -These endpoints require authentication. -""" - - -class TestAuthentication: - """Integration tests for authentication on database endpoints.""" - - def test_databases_list_requires_auth(self, client): - """GET /v1/databases should require authentication.""" - response = client.get("/v1/databases") - - assert response.status_code == 401 # No auth header = Unauthorized - - def test_databases_list_rejects_invalid_token(self, client, auth_headers): - """GET /v1/databases should reject invalid tokens.""" - response = client.get("/v1/databases", headers=auth_headers["invalid_auth"]) - - assert response.status_code == 401 - - def test_databases_list_accepts_valid_token(self, client, auth_headers): - """GET /v1/databases should accept valid tokens.""" - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - - -class TestDatabasesList: - """Integration tests for the databases list endpoint.""" - - def test_databases_list_returns_list(self, client, auth_headers): - """GET /v1/databases should return a list of databases.""" - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - data = response.json() - assert isinstance(data, list) - - def test_databases_list_contains_default(self, client, auth_headers): - """GET /v1/databases should contain a DEFAULT database.""" - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - - data = response.json() - # There should be at least one database (DEFAULT is created by bootstrap) - # If no config file, the list may be empty or contain DEFAULT - assert isinstance(data, list) - - def test_databases_list_returns_database_schema(self, client, auth_headers, db_objects_manager, make_database): - """GET /v1/databases should return databases with correct schema.""" - # Ensure there's at least one database for testing - if not db_objects_manager: - db_objects_manager.append(make_database()) - - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - data = response.json() - if data: - db = data[0] - assert "name" in db - assert "user" in db - assert "dsn" in db - assert "connected" in db - - -class TestDatabasesGet: - """Integration tests for the single database get endpoint.""" - - def test_databases_get_requires_auth(self, client): - """GET /v1/databases/{name} should require authentication.""" - response = client.get("/v1/databases/DEFAULT") - - assert response.status_code == 401 - - def test_databases_get_returns_404_for_unknown(self, client, auth_headers): - """GET /v1/databases/{name} should return 404 for unknown database.""" - response = client.get("/v1/databases/NONEXISTENT_DB", headers=auth_headers["valid_auth"]) - - assert response.status_code == 404 - - def test_databases_get_returns_database(self, client, auth_headers, db_objects_manager, make_database): - """GET /v1/databases/{name} should return the specified 
database.""" - # Ensure there's a test database - test_db = make_database(name="INTEGRATION_TEST_DB") - db_objects_manager.append(test_db) - - response = client.get("/v1/databases/INTEGRATION_TEST_DB", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - data = response.json() - assert data["name"] == "INTEGRATION_TEST_DB" - - -class TestDatabasesUpdate: - """Integration tests for the database update endpoint.""" - - def test_databases_update_requires_auth(self, client): - """PATCH /v1/databases/{name} should require authentication.""" - response = client.patch("/v1/databases/DEFAULT", json={"user": "test"}) - - assert response.status_code == 401 - - def test_databases_update_returns_404_for_unknown(self, client, auth_headers): - """PATCH /v1/databases/{name} should return 404 for unknown database.""" - response = client.patch( - "/v1/databases/NONEXISTENT_DB", - headers=auth_headers["valid_auth"], - json={"user": "test", "password": "test", "dsn": "localhost:1521/TEST"}, - ) - - assert response.status_code == 404 - - def test_databases_update_validates_connection(self, client, auth_headers, db_objects_manager, make_database): - """PATCH /v1/databases/{name} should validate connection details.""" - # Add a test database - test_db = make_database(name="UPDATE_TEST_DB") - db_objects_manager.append(test_db) - - # Try to update with invalid connection details (no real DB running) - response = client.patch( - "/v1/databases/UPDATE_TEST_DB", - headers=auth_headers["valid_auth"], - json={"user": "invalid", "password": "invalid", "dsn": "localhost:9999/INVALID"}, - ) - - # Should fail because it tries to connect - assert response.status_code in [400, 401, 404, 503] - - def test_databases_update_connects_to_real_db( - self, client, auth_headers, db_objects_manager, db_container, test_db_payload, make_database - ): - """PATCH /v1/databases/{name} should connect to real database.""" - _ = db_container # Ensure container is running - # Add a test database - test_db = make_database(name="REAL_DB_TEST", user="placeholder", password="placeholder", dsn="placeholder") - db_objects_manager.append(test_db) - - response = client.patch( - "/v1/databases/REAL_DB_TEST", - headers=auth_headers["valid_auth"], - json=test_db_payload, - ) - - assert response.status_code == 200 - data = response.json() - assert data["connected"] is True - assert data["user"] == test_db_payload["user"] diff --git a/test/integration/server/api/v1/test_models.py b/test/integration/server/api/v1/test_models.py deleted file mode 100644 index 74cbd11b..00000000 --- a/test/integration/server/api/v1/test_models.py +++ /dev/null @@ -1,271 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/api/v1/models.py - -Tests the model configuration endpoints through the full API stack. -These endpoints require authentication. 
-""" - - -class TestAuthentication: - """Integration tests for authentication on model endpoints.""" - - def test_models_list_requires_auth(self, client): - """GET /v1/models should require authentication.""" - response = client.get("/v1/models") - - assert response.status_code == 401 - - def test_models_list_rejects_invalid_token(self, client, auth_headers): - """GET /v1/models should reject invalid tokens.""" - response = client.get("/v1/models", headers=auth_headers["invalid_auth"]) - - assert response.status_code == 401 - - def test_models_list_accepts_valid_token(self, client, auth_headers): - """GET /v1/models should accept valid tokens.""" - response = client.get("/v1/models", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - - -class TestModelsList: - """Integration tests for the models list endpoint.""" - - def test_models_list_returns_list(self, client, auth_headers): - """GET /v1/models should return a list of models.""" - response = client.get("/v1/models", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - data = response.json() - assert isinstance(data, list) - - def test_models_list_returns_enabled_only_by_default(self, client, auth_headers): - """GET /v1/models should return only enabled models by default.""" - response = client.get("/v1/models", headers=auth_headers["valid_auth"]) - - data = response.json() - for model in data: - assert model["enabled"] is True - - def test_models_list_with_include_disabled(self, client, auth_headers): - """GET /v1/models?include_disabled=true should include disabled models.""" - response = client.get( - "/v1/models", - headers=auth_headers["valid_auth"], - params={"include_disabled": True}, - ) - - assert response.status_code == 200 - data = response.json() - # Should have at least some models (bootstrap loads defaults) - assert isinstance(data, list) - - def test_models_list_filter_by_type_ll(self, client, auth_headers): - """GET /v1/models?model_type=ll should return only LL models.""" - response = client.get( - "/v1/models", - headers=auth_headers["valid_auth"], - params={"model_type": "ll", "include_disabled": True}, - ) - - assert response.status_code == 200 - data = response.json() - for model in data: - assert model["type"] == "ll" - - def test_models_list_filter_by_type_embed(self, client, auth_headers): - """GET /v1/models?model_type=embed should return only embed models.""" - response = client.get( - "/v1/models", - headers=auth_headers["valid_auth"], - params={"model_type": "embed", "include_disabled": True}, - ) - - assert response.status_code == 200 - data = response.json() - for model in data: - assert model["type"] == "embed" - - -class TestModelsSupported: - """Integration tests for the supported models endpoint.""" - - def test_models_supported_returns_list(self, client, auth_headers): - """GET /v1/models/supported should return supported providers.""" - response = client.get("/v1/models/supported", headers=auth_headers["valid_auth"]) - - assert response.status_code == 200 - data = response.json() - assert isinstance(data, list) - - def test_models_supported_filter_by_provider(self, client, auth_headers): - """GET /v1/models/supported?model_provider=openai should filter by provider.""" - response = client.get( - "/v1/models/supported", - headers=auth_headers["valid_auth"], - params={"model_provider": "openai"}, - ) - - assert response.status_code == 200 - data = response.json() - for item in data: - assert item.get("provider") == "openai" - - def 
test_models_supported_filter_by_type(self, client, auth_headers): - """GET /v1/models/supported?model_type=ll should filter by type.""" - response = client.get( - "/v1/models/supported", - headers=auth_headers["valid_auth"], - params={"model_type": "ll"}, - ) - - assert response.status_code == 200 - data = response.json() - # Response is a list of provider objects with provider and models keys - assert isinstance(data, list) - # Each item should have provider and models keys - for item in data: - assert "provider" in item - assert "models" in item - - -class TestModelsGet: - """Integration tests for the single model get endpoint.""" - - def test_models_get_requires_auth(self, client): - """GET /v1/models/{provider}/{id} should require authentication.""" - response = client.get("/v1/models/openai/gpt-4o-mini") - - assert response.status_code == 401 - - def test_models_get_returns_404_for_unknown(self, client, auth_headers): - """GET /v1/models/{provider}/{id} should return 404 for unknown model.""" - response = client.get( - "/v1/models/nonexistent/nonexistent-model", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - def test_models_get_returns_model(self, client, auth_headers, model_objects_manager, make_model): - """GET /v1/models/{provider}/{id} should return the specified model.""" - # Add a test model - test_model = make_model(id="integration-test-model") - model_objects_manager.append(test_model) - - response = client.get( - "/v1/models/openai/integration-test-model", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 200 - data = response.json() - assert data["id"] == "integration-test-model" - assert data["provider"] == "openai" - - -class TestModelsCreate: - """Integration tests for the model create endpoint.""" - - def test_models_create_requires_auth(self, client): - """POST /v1/models should require authentication.""" - response = client.post( - "/v1/models", - json={"id": "test-model", "type": "ll", "provider": "openai", "enabled": True}, - ) - - assert response.status_code == 401 - - def test_models_create_success(self, client, auth_headers, model_objects_manager): - """POST /v1/models should create a new model.""" - # pylint: disable=unused-argument - response = client.post( - "/v1/models", - headers=auth_headers["valid_auth"], - json={"id": "new-test-model", "type": "ll", "provider": "openai", "enabled": True}, - ) - - assert response.status_code == 201 - data = response.json() - assert data["id"] == "new-test-model" - assert data["provider"] == "openai" - - def test_models_create_returns_409_for_duplicate(self, client, auth_headers, model_objects_manager, make_model): - """POST /v1/models should return 409 for duplicate model.""" - # Add existing model - existing_model = make_model(id="duplicate-model") - model_objects_manager.append(existing_model) - - response = client.post( - "/v1/models", - headers=auth_headers["valid_auth"], - json={"id": "duplicate-model", "type": "ll", "provider": "openai", "enabled": True}, - ) - - assert response.status_code == 409 - - -class TestModelsUpdate: - """Integration tests for the model update endpoint.""" - - def test_models_update_requires_auth(self, client): - """PATCH /v1/models/{provider}/{id} should require authentication.""" - response = client.patch( - "/v1/models/openai/test-model", - json={"id": "test-model", "type": "ll", "provider": "openai", "enabled": False}, - ) - - assert response.status_code == 401 - - def test_models_update_returns_404_for_unknown(self, client, 
auth_headers): - """PATCH /v1/models/{provider}/{id} should return 404 for unknown model.""" - response = client.patch( - "/v1/models/nonexistent/nonexistent-model", - headers=auth_headers["valid_auth"], - json={"id": "nonexistent-model", "type": "ll", "provider": "nonexistent", "enabled": False}, - ) - - assert response.status_code == 404 - - def test_models_update_success(self, client, auth_headers, model_objects_manager, make_model): - """PATCH /v1/models/{provider}/{id} should update the model.""" - # Add a test model - test_model = make_model(id="update-test-model") - model_objects_manager.append(test_model) - - response = client.patch( - "/v1/models/openai/update-test-model", - headers=auth_headers["valid_auth"], - json={"id": "update-test-model", "type": "ll", "provider": "openai", "enabled": False}, - ) - - assert response.status_code == 200 - data = response.json() - assert data["enabled"] is False - - -class TestModelsDelete: - """Integration tests for the model delete endpoint.""" - - def test_models_delete_requires_auth(self, client): - """DELETE /v1/models/{provider}/{id} should require authentication.""" - response = client.delete("/v1/models/openai/test-model") - - assert response.status_code == 401 - - def test_models_delete_success(self, client, auth_headers, model_objects_manager, make_model): - """DELETE /v1/models/{provider}/{id} should delete the model.""" - # Add a test model to delete - test_model = make_model(id="delete-test-model") - model_objects_manager.append(test_model) - - response = client.delete( - "/v1/models/openai/delete-test-model", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 200 - assert "deleted" in response.json()["message"].lower() diff --git a/test/integration/server/api/v1/test_oci.py b/test/integration/server/api/v1/test_oci.py deleted file mode 100644 index aeed656b..00000000 --- a/test/integration/server/api/v1/test_oci.py +++ /dev/null @@ -1,224 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/api/v1/oci.py - -Tests the OCI configuration endpoints through the full API stack. -These endpoints require authentication. - -Note: Most OCI operations require valid OCI credentials. Tests without -real OCI credentials will verify endpoint availability and authentication. 
-""" - - -class TestOciList: - """Integration tests for the OCI list endpoint.""" - - def test_oci_list_requires_auth(self, client): - """GET /v1/oci should require authentication.""" - response = client.get("/v1/oci") - - assert response.status_code == 401 - - def test_oci_list_rejects_invalid_token(self, client, auth_headers): - """GET /v1/oci should reject invalid tokens.""" - response = client.get("/v1/oci", headers=auth_headers["invalid_auth"]) - - assert response.status_code == 401 - - def test_oci_list_accepts_valid_token(self, client, auth_headers): - """GET /v1/oci should accept valid tokens.""" - response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) - - # May return 200 (with configs) or 404 (no configs) - assert response.status_code in [200, 404] - - def test_oci_list_returns_list_or_404(self, client, auth_headers): - """GET /v1/oci should return a list of OCI configs or 404 if none.""" - response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) - - if response.status_code == 200: - data = response.json() - assert isinstance(data, list) - else: - assert response.status_code == 404 - - -class TestOciGet: - """Integration tests for the single OCI profile get endpoint.""" - - def test_oci_get_requires_auth(self, client): - """GET /v1/oci/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/DEFAULT") - - assert response.status_code == 401 - - def test_oci_get_returns_404_for_unknown(self, client, auth_headers): - """GET /v1/oci/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - -class TestOciRegions: - """Integration tests for the OCI regions endpoint.""" - - def test_oci_regions_requires_auth(self, client): - """GET /v1/oci/regions/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/regions/DEFAULT") - - assert response.status_code == 401 - - def test_oci_regions_returns_404_for_unknown_profile(self, client, auth_headers): - """GET /v1/oci/regions/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/regions/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - -class TestOciGenai: - """Integration tests for the OCI GenAI models endpoint.""" - - def test_oci_genai_requires_auth(self, client): - """GET /v1/oci/genai/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/genai/DEFAULT") - - assert response.status_code == 401 - - def test_oci_genai_returns_404_for_unknown_profile(self, client, auth_headers): - """GET /v1/oci/genai/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/genai/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - -class TestOciCompartments: - """Integration tests for the OCI compartments endpoint.""" - - def test_oci_compartments_requires_auth(self, client): - """GET /v1/oci/compartments/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/compartments/DEFAULT") - - assert response.status_code == 401 - - def test_oci_compartments_returns_404_for_unknown_profile(self, client, auth_headers): - """GET /v1/oci/compartments/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/compartments/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - 
assert response.status_code == 404 - - -class TestOciBuckets: - """Integration tests for the OCI buckets endpoint.""" - - def test_oci_buckets_requires_auth(self, client): - """GET /v1/oci/buckets/{compartment_ocid}/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/buckets/ocid1.compartment.oc1..test/DEFAULT") - - assert response.status_code == 401 - - def test_oci_buckets_returns_404_for_unknown_profile(self, client, auth_headers): - """GET /v1/oci/buckets/{compartment_ocid}/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/buckets/ocid1.compartment.oc1..test/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - -class TestOciObjects: - """Integration tests for the OCI bucket objects endpoint.""" - - def test_oci_objects_requires_auth(self, client): - """GET /v1/oci/objects/{bucket_name}/{auth_profile} should require authentication.""" - response = client.get("/v1/oci/objects/test-bucket/DEFAULT") - - assert response.status_code == 401 - - def test_oci_objects_returns_404_for_unknown_profile(self, client, auth_headers): - """GET /v1/oci/objects/{bucket_name}/{auth_profile} should return 404 for unknown profile.""" - response = client.get( - "/v1/oci/objects/test-bucket/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) - - assert response.status_code == 404 - - -class TestOciUpdate: - """Integration tests for the OCI profile update endpoint.""" - - def test_oci_update_requires_auth(self, client): - """PATCH /v1/oci/{auth_profile} should require authentication.""" - response = client.patch( - "/v1/oci/DEFAULT", - json={"auth_profile": "DEFAULT", "genai_region": "us-ashburn-1"}, - ) - - assert response.status_code == 401 - - def test_oci_update_returns_404_for_unknown_profile(self, client, auth_headers): - """PATCH /v1/oci/{auth_profile} should return 404 for unknown profile.""" - response = client.patch( - "/v1/oci/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - json={"auth_profile": "NONEXISTENT_PROFILE", "genai_region": "us-ashburn-1"}, - ) - - assert response.status_code == 404 - - -class TestOciDownloadObjects: - """Integration tests for the OCI download objects endpoint.""" - - def test_oci_download_requires_auth(self, client): - """POST /v1/oci/objects/download/{bucket_name}/{auth_profile} should require authentication.""" - response = client.post( - "/v1/oci/objects/download/test-bucket/DEFAULT", - json=["file1.txt"], - ) - - assert response.status_code == 401 - - def test_oci_download_returns_404_for_unknown_profile(self, client, auth_headers): - """POST /v1/oci/objects/download/{bucket_name}/{auth_profile} should return 404 for unknown profile.""" - response = client.post( - "/v1/oci/objects/download/test-bucket/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - json=["file1.txt"], - ) - - assert response.status_code == 404 - - -class TestOciCreateGenaiModels: - """Integration tests for the OCI create GenAI models endpoint.""" - - def test_oci_create_genai_requires_auth(self, client): - """POST /v1/oci/genai/{auth_profile} should require authentication.""" - response = client.post("/v1/oci/genai/DEFAULT") - - assert response.status_code == 401 - - def test_oci_create_genai_returns_404_for_unknown_profile(self, client, auth_headers): - """POST /v1/oci/genai/{auth_profile} should return 404 for unknown profile.""" - response = client.post( - "/v1/oci/genai/NONEXISTENT_PROFILE", - headers=auth_headers["valid_auth"], - ) 
-
-        assert response.status_code == 404
diff --git a/test/integration/server/api/v1/test_probes.py b/test/integration/server/api/v1/test_probes.py
deleted file mode 100644
index 9d0401e8..00000000
--- a/test/integration/server/api/v1/test_probes.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Integration tests for server/api/v1/probes.py
-
-Tests the Kubernetes probe endpoints (liveness, readiness, MCP health).
-These endpoints do not require authentication.
-"""
-
-
-class TestLivenessProbe:
-    """Integration tests for the liveness probe endpoint."""
-
-    def test_liveness_returns_200(self, client):
-        """GET /v1/liveness should return 200 with status alive."""
-        response = client.get("/v1/liveness")
-
-        assert response.status_code == 200
-        assert response.json() == {"status": "alive"}
-
-    def test_liveness_no_auth_required(self, client):
-        """GET /v1/liveness should not require authentication."""
-        # No auth headers provided
-        response = client.get("/v1/liveness")
-
-        assert response.status_code == 200
-
-
-class TestReadinessProbe:
-    """Integration tests for the readiness probe endpoint."""
-
-    def test_readiness_returns_200(self, client):
-        """GET /v1/readiness should return 200 with status ready."""
-        response = client.get("/v1/readiness")
-
-        assert response.status_code == 200
-        assert response.json() == {"status": "ready"}
-
-    def test_readiness_no_auth_required(self, client):
-        """GET /v1/readiness should not require authentication."""
-        response = client.get("/v1/readiness")
-
-        assert response.status_code == 200
-
-
-class TestMcpHealthz:
-    """Integration tests for the MCP health check endpoint."""
-
-    def test_mcp_healthz_returns_200(self, client):
-        """GET /v1/mcp/healthz should return 200 with MCP status."""
-        response = client.get("/v1/mcp/healthz")
-
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "ready"
-        assert "name" in data
-        assert "version" in data
-        assert "available_tools" in data
-
-    def test_mcp_healthz_no_auth_required(self, client):
-        """GET /v1/mcp/healthz should not require authentication."""
-        response = client.get("/v1/mcp/healthz")
-
-        assert response.status_code == 200
-
-    def test_mcp_healthz_returns_server_info(self, client):
-        """GET /v1/mcp/healthz should return MCP server information."""
-        response = client.get("/v1/mcp/healthz")
-
-        data = response.json()
-        assert data["name"] == "Oracle AI Optimizer and Toolkit MCP Server"
-        assert isinstance(data["available_tools"], int)
-        assert data["available_tools"] >= 0
diff --git a/test/integration/server/api/v1/test_settings.py b/test/integration/server/api/v1/test_settings.py
deleted file mode 100644
index 15374061..00000000
--- a/test/integration/server/api/v1/test_settings.py
+++ /dev/null
@@ -1,307 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Integration tests for server/api/v1/settings.py
-
-Tests the settings configuration endpoints through the full API stack.
-These endpoints require authentication.
-""" - -import json -from io import BytesIO - - -class TestAuthentication: - """Integration tests for authentication on settings endpoints.""" - - def test_settings_get_requires_auth(self, client): - """GET /v1/settings should require authentication.""" - response = client.get("/v1/settings", params={"client": "test"}) - - assert response.status_code == 401 - - def test_settings_get_rejects_invalid_token(self, client, auth_headers): - """GET /v1/settings should reject invalid tokens.""" - response = client.get( - "/v1/settings", - headers=auth_headers["invalid_auth"], - params={"client": "test"}, - ) - - assert response.status_code == 401 - - def test_settings_get_accepts_valid_token(self, client, auth_headers): - """GET /v1/settings should accept valid tokens.""" - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server"}, # Use existing client - ) - - assert response.status_code == 200 - - -class TestSettingsGet: - """Integration tests for the settings get endpoint.""" - - def test_settings_get_returns_settings(self, client, auth_headers): - """GET /v1/settings should return settings for existing client.""" - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - ) - - assert response.status_code == 200 - data = response.json() - assert "client" in data - assert data["client"] == "server" - - def test_settings_get_returns_404_for_unknown_client(self, client, auth_headers): - """GET /v1/settings should return 404 for unknown client.""" - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "nonexistent_client_xyz"}, - ) - - assert response.status_code == 404 - - def test_settings_get_full_config(self, client, auth_headers): - """GET /v1/settings?full_config=true should return full configuration.""" - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server", "full_config": True}, - ) - - assert response.status_code == 200 - data = response.json() - # Full config includes client_settings and all config arrays - assert "client_settings" in data - assert "database_configs" in data - assert "model_configs" in data - assert "oci_configs" in data - assert "prompt_configs" in data - - def test_settings_get_with_sensitive(self, client, auth_headers): - """GET /v1/settings?incl_sensitive=true should include sensitive fields.""" - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server", "full_config": True, "incl_sensitive": True}, - ) - - assert response.status_code == 200 - # Response should include sensitive fields (passwords) - # Exact fields depend on what's configured - - -class TestSettingsCreate: - """Integration tests for the settings create endpoint.""" - - def test_settings_create_requires_auth(self, client): - """POST /v1/settings should require authentication.""" - response = client.post("/v1/settings", params={"client": "new_test_client"}) - - assert response.status_code == 401 - - def test_settings_create_success(self, client, auth_headers, settings_objects_manager): - """POST /v1/settings should create new client settings.""" - # pylint: disable=unused-argument - response = client.post( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "integration_new_client"}, - ) - - assert response.status_code == 200 - data = response.json() - assert data["client"] == "integration_new_client" - - def 
test_settings_create_returns_409_for_existing(self, client, auth_headers): - """POST /v1/settings should return 409 if client already exists.""" - # "server" client is created by bootstrap - response = client.post( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - ) - - assert response.status_code == 409 - - -class TestSettingsUpdate: - """Integration tests for the settings update endpoint.""" - - def test_settings_update_requires_auth(self, client): - """PATCH /v1/settings should require authentication.""" - response = client.patch( - "/v1/settings", - params={"client": "server"}, - json={"client": "server"}, - ) - - assert response.status_code == 401 - - def test_settings_update_returns_404_for_unknown(self, client, auth_headers): - """PATCH /v1/settings should return 404 for unknown client.""" - response = client.patch( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "nonexistent_client_xyz"}, - json={"client": "nonexistent_client_xyz"}, - ) - - assert response.status_code == 404 - - def test_settings_update_success(self, client, auth_headers, settings_objects_manager): - """PATCH /v1/settings should update client settings.""" - # pylint: disable=unused-argument - # First create a client to update - client.post( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "update_test_client"}, - ) - - # Now update it - response = client.patch( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "update_test_client"}, - json={ - "client": "update_test_client", - "ll_model": { - "model": "gpt-4o", - "temperature": 0.5, - "max_tokens": 2048, - "chat_history": False, - }, - }, - ) - - assert response.status_code == 200 - data = response.json() - assert data["ll_model"]["temperature"] == 0.5 - - -class TestSettingsLoadFromFile: - """Integration tests for the settings load from file endpoint.""" - - def test_load_from_file_requires_auth(self, client): - """POST /v1/settings/load/file should require authentication.""" - response = client.post( - "/v1/settings/load/file", - params={"client": "test"}, - files={"file": ("test.json", b"{}", "application/json")}, - ) - - assert response.status_code == 401 - - def test_load_from_file_rejects_non_json_extension(self, client, auth_headers, settings_objects_manager): - """POST /v1/settings/load/file should reject files without .json extension. - - Note: Current implementation returns 500 due to HTTPException being caught - by generic Exception handler. This documents actual behavior. 
- """ - # pylint: disable=unused-argument - response = client.post( - "/v1/settings/load/file", - headers=auth_headers["valid_auth"], - params={"client": "file_test_client"}, - files={"file": ("test.txt", b"{}", "text/plain")}, - ) - - # Current behavior returns 500 (HTTPException caught by generic handler) - # Ideally should be 400, but documenting actual behavior - assert response.status_code == 500 - assert "Only JSON files are supported" in response.json()["detail"] - - def test_load_from_file_rejects_invalid_json_content(self, client, auth_headers, settings_objects_manager): - """POST /v1/settings/load/file should reject invalid JSON content.""" - # pylint: disable=unused-argument - response = client.post( - "/v1/settings/load/file", - headers=auth_headers["valid_auth"], - params={"client": "file_invalid_content"}, - files={"file": ("test.json", b"not valid json", "application/json")}, - ) - - # Invalid JSON content returns 400 - assert response.status_code == 400 - - def test_load_from_file_success(self, client, auth_headers, settings_objects_manager): - """POST /v1/settings/load/file should load configuration from JSON file.""" - # pylint: disable=unused-argument - config_data = { - "client_settings": { - "client": "file_load_client", - "ll_model": { - "model": "gpt-4o-mini", - "temperature": 0.8, - "max_tokens": 1000, - "chat_history": True, - }, - }, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - file_content = json.dumps(config_data).encode("utf-8") - - response = client.post( - "/v1/settings/load/file", - headers=auth_headers["valid_auth"], - params={"client": "file_load_client"}, - files={"file": ("config.json", BytesIO(file_content), "application/json")}, - ) - - assert response.status_code == 200 - assert "loaded successfully" in response.json()["message"].lower() - - -class TestSettingsLoadFromJson: - """Integration tests for the settings load from JSON endpoint.""" - - def test_load_from_json_requires_auth(self, client): - """POST /v1/settings/load/json should require authentication.""" - response = client.post( - "/v1/settings/load/json", - params={"client": "test"}, - json={"client_settings": {"client": "test"}}, - ) - - assert response.status_code == 401 - - def test_load_from_json_success(self, client, auth_headers, settings_objects_manager): - """POST /v1/settings/load/json should load configuration from JSON payload.""" - # pylint: disable=unused-argument - config_data = { - "client_settings": { - "client": "json_load_client", - "ll_model": { - "model": "gpt-4o-mini", - "temperature": 0.9, - "max_tokens": 500, - "chat_history": True, - }, - }, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "json_load_client"}, - json=config_data, - ) - - assert response.status_code == 200 - assert "loaded successfully" in response.json()["message"].lower() diff --git a/test/integration/server/bootstrap/__init__.py b/test/integration/server/bootstrap/__init__.py deleted file mode 100644 index 90dc5216..00000000 --- a/test/integration/server/bootstrap/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Bootstrap integration test package diff --git a/test/integration/server/bootstrap/conftest.py b/test/integration/server/bootstrap/conftest.py deleted file mode 100644 index 00848e66..00000000 --- a/test/integration/server/bootstrap/conftest.py +++ /dev/null @@ -1,147 +0,0 @@ -""" 
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Pytest fixtures for server/bootstrap integration tests.
-
-Integration tests for bootstrap test the actual bootstrap process with real
-file I/O, environment variables, and configuration loading. These tests
-verify end-to-end behavior of the bootstrap system.
-"""
-
-# pylint: disable=redefined-outer-name unused-import
-
-import json
-import tempfile
-from pathlib import Path
-
-# Re-export shared fixtures for pytest discovery
-from test.shared_fixtures import (
-    reset_config_store,
-    clean_env,
-    BOOTSTRAP_ENV_VARS,
-    DEFAULT_LL_MODEL_CONFIG,
-)
-
-import pytest
-
-# Alias for backwards compatibility
-clean_bootstrap_env = clean_env
-
-
-@pytest.fixture
-def temp_dir():
-    """Create a temporary directory for test files."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        yield Path(tmpdir)
-
-
-@pytest.fixture
-def make_config_file(temp_dir):
-    """Factory fixture to create real configuration JSON files."""
-
-    def _make_config_file(
-        filename: str = "configuration.json",
-        client_settings: dict = None,
-        database_configs: list = None,
-        model_configs: list = None,
-        oci_configs: list = None,
-        prompt_configs: list = None,
-    ):
-        config_data = {
-            "client_settings": client_settings or {"client": "test_client"},
-            "database_configs": database_configs or [],
-            "model_configs": model_configs or [],
-            "oci_configs": oci_configs or [],
-            "prompt_configs": prompt_configs or [],
-        }
-
-        file_path = temp_dir / filename
-        with open(file_path, "w", encoding="utf-8") as f:
-            json.dump(config_data, f, indent=2)
-
-        return file_path
-
-    return _make_config_file
-
-
-@pytest.fixture
-def make_oci_config_file(temp_dir):
-    """Factory fixture to create real OCI configuration files."""
-
-    def _make_oci_config_file(
-        filename: str = "config",
-        profiles: dict = None,
-    ):
-        """Create an OCI-style config file.
-
-        Args:
-            filename: Name of the config file
-            profiles: Dict of profile_name -> dict of key-value pairs
-                e.g., {"DEFAULT": {"tenancy": "...", "region": "..."}}
-        """
-        if profiles is None:
-            profiles = {
-                "DEFAULT": {
-                    "tenancy": "ocid1.tenancy.oc1..testtenancy",
-                    "region": "us-ashburn-1",
-                    "fingerprint": "test:fingerprint",
-                }
-            }
-
-        file_path = temp_dir / filename
-        with open(file_path, "w", encoding="utf-8") as f:
-            for profile_name, settings in profiles.items():
-                f.write(f"[{profile_name}]\n")
-                for key, value in settings.items():
-                    f.write(f"{key}={value}\n")
-                f.write("\n")
-
-        return file_path
-
-    return _make_oci_config_file
-
-
-@pytest.fixture
-def sample_database_config():
-    """Sample database configuration dict."""
-    return {
-        "name": "INTEGRATION_DB",
-        "user": "integration_user",
-        "password": "integration_pass",
-        "dsn": "localhost:1521/INTPDB",
-    }
-
-
-@pytest.fixture
-def sample_model_config():
-    """Sample model configuration dict."""
-    return {
-        "id": "integration-model",
-        "type": "ll",
-        "provider": "openai",
-        "enabled": True,
-        "api_key": "test-api-key",
-        "api_base": "https://api.openai.com/v1",
-        "max_tokens": 4096,
-    }
-
-
-@pytest.fixture
-def sample_oci_config():
-    """Sample OCI configuration dict."""
-    return {
-        "auth_profile": "INTEGRATION",
-        "tenancy": "ocid1.tenancy.oc1..integration",
-        "region": "us-phoenix-1",
-        "fingerprint": "integration:fingerprint",
-    }
-
-
-@pytest.fixture
-def sample_settings_config():
-    """Sample settings configuration dict."""
-    return {
-        "client": "integration_client",
-        "ll_model": DEFAULT_LL_MODEL_CONFIG.copy(),
-    }
diff --git a/test/integration/server/bootstrap/test_bootstrap_configfile.py b/test/integration/server/bootstrap/test_bootstrap_configfile.py
deleted file mode 100644
index 48cc2943..00000000
--- a/test/integration/server/bootstrap/test_bootstrap_configfile.py
+++ /dev/null
@@ -1,245 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Integration tests for server/bootstrap/configfile.py
-
-Tests the ConfigStore class with real file I/O operations.
-""" - -# pylint: disable=redefined-outer-name - -import json -import os -from pathlib import Path - -import pytest - -from server.bootstrap.configfile import config_file_path - - -class TestConfigStoreFileOperations: - """Integration tests for ConfigStore with real file operations.""" - - def test_load_valid_json_file(self, reset_config_store, make_config_file, sample_settings_config): - """ConfigStore should load a valid JSON configuration file.""" - config_path = make_config_file( - client_settings=sample_settings_config, - ) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert config.client_settings.client == "integration_client" - - def test_load_file_with_all_sections( - self, - reset_config_store, - make_config_file, - sample_settings_config, - sample_database_config, - sample_model_config, - sample_oci_config, - ): - """ConfigStore should load file with all configuration sections.""" - config_path = make_config_file( - client_settings=sample_settings_config, - database_configs=[sample_database_config], - model_configs=[sample_model_config], - oci_configs=[sample_oci_config], - ) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert len(config.database_configs) == 1 - assert config.database_configs[0].name == "INTEGRATION_DB" - assert len(config.model_configs) == 1 - assert config.model_configs[0].id == "integration-model" - assert len(config.oci_configs) == 1 - assert config.oci_configs[0].auth_profile == "INTEGRATION" - - def test_load_nonexistent_file_returns_none(self, reset_config_store, temp_dir): - """ConfigStore should handle nonexistent files gracefully.""" - nonexistent_path = temp_dir / "does_not_exist.json" - - reset_config_store.load_from_file(nonexistent_path) - config = reset_config_store.get() - - assert config is None - - def test_load_file_with_unicode_content(self, reset_config_store, temp_dir): - """ConfigStore should handle files with unicode content.""" - config_data = { - "client_settings": {"client": "unicode_test_客户端"}, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - config_path = temp_dir / "unicode_config.json" - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f, ensure_ascii=False) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert config.client_settings.client == "unicode_test_客户端" - - def test_load_file_with_nested_settings(self, reset_config_store, temp_dir): - """ConfigStore should handle deeply nested settings.""" - config_data = { - "client_settings": { - "client": "nested_test", - "ll_model": { - "model": "gpt-4o-mini", - "temperature": 0.5, - "max_tokens": 2048, - "chat_history": True, - }, - "vector_search": { - "discovery": True, - "rephrase": True, - "grade": True, - "top_k": 5, - }, - }, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - config_path = temp_dir / "nested_config.json" - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert config.client_settings.ll_model.temperature == 0.5 - assert config.client_settings.vector_search.top_k == 5 - - def test_load_large_config_file(self, reset_config_store, temp_dir): - """ConfigStore should handle 
large configuration files.""" - # Create config with many database entries - database_configs = [ - { - "name": f"DB_{i}", - "user": f"user_{i}", - "password": f"pass_{i}", - "dsn": f"host{i}:1521/PDB{i}", - } - for i in range(50) - ] - - config_data = { - "client_settings": {"client": "large_test"}, - "database_configs": database_configs, - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - config_path = temp_dir / "large_config.json" - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert len(config.database_configs) == 50 - - def test_load_file_preserves_field_types(self, reset_config_store, temp_dir): - """ConfigStore should preserve correct field types after loading.""" - config_data = { - "client_settings": { - "client": "type_test", - "ll_model": { - "model": "test-model", - "temperature": 0.7, # float - "max_tokens": 4096, # int - "chat_history": True, # bool - }, - }, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - config_path = temp_dir / "types_config.json" - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f) - - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert isinstance(config.client_settings.ll_model.temperature, float) - assert isinstance(config.client_settings.ll_model.max_tokens, int) - assert isinstance(config.client_settings.ll_model.chat_history, bool) - - -class TestConfigStoreValidation: - """Integration tests for ConfigStore validation with real files.""" - - def test_load_file_validates_required_fields(self, reset_config_store, temp_dir): - """ConfigStore should validate required fields in config.""" - # Missing required 'client' field in client_settings - config_data = { - "client_settings": {}, # Missing 'client' - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - config_path = temp_dir / "invalid_config.json" - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f) - - with pytest.raises(Exception): # Pydantic ValidationError - reset_config_store.load_from_file(config_path) - - def test_load_malformed_json_raises_error(self, reset_config_store, temp_dir): - """ConfigStore should raise error for malformed JSON.""" - config_path = temp_dir / "malformed.json" - with open(config_path, "w", encoding="utf-8") as f: - f.write("{ invalid json content }") - - with pytest.raises(json.JSONDecodeError): - reset_config_store.load_from_file(config_path) - - -class TestConfigFilePath: - """Integration tests for config_file_path function.""" - - def test_config_file_path_returns_valid_path(self): - """config_file_path should return a valid filesystem path.""" - path = config_file_path() - - assert path is not None - assert isinstance(path, str) - assert path.endswith("configuration.json") - - def test_config_file_path_parent_directory_structure(self): - """config_file_path should point to server/etc directory.""" - path = config_file_path() - path_obj = Path(path) - - # Parent should be 'etc' directory - assert path_obj.parent.name == "etc" - # Grandparent should be 'server' directory - assert path_obj.parent.parent.name == "server" - - def test_config_file_path_is_absolute(self): - """config_file_path should return an absolute path.""" - path = config_file_path() - - assert os.path.isabs(path) diff --git 
a/test/integration/server/bootstrap/test_bootstrap_databases.py b/test/integration/server/bootstrap/test_bootstrap_databases.py deleted file mode 100644 index 70336db5..00000000 --- a/test/integration/server/bootstrap/test_bootstrap_databases.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/bootstrap/databases.py - -Tests the database bootstrap process with real configuration files -and environment variables. -""" - -# pylint: disable=redefined-outer-name - -import os - -from test.shared_fixtures import ( - assert_database_list_valid, - assert_has_default_database, - get_database_by_name, -) - -import pytest - -from server.bootstrap import databases as databases_module - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestDatabasesBootstrapWithConfig: - """Integration tests for database bootstrap with configuration files.""" - - def test_bootstrap_returns_database_objects(self): - """databases.main() should return list of Database objects.""" - result = databases_module.main() - assert_database_list_valid(result) - - def test_bootstrap_creates_default_database(self): - """databases.main() should always create DEFAULT database.""" - result = databases_module.main() - assert_has_default_database(result) - - def test_bootstrap_with_config_file_databases(self, reset_config_store, make_config_file): - """databases.main() should load databases from config file.""" - config_path = make_config_file( - database_configs=[ - { - "name": "CONFIG_DB1", - "user": "config_user1", - "password": "config_pass1", - "dsn": "host1:1521/PDB1", - }, - { - "name": "CONFIG_DB2", - "user": "config_user2", - "password": "config_pass2", - "dsn": "host2:1521/PDB2", - }, - ], - ) - - reset_config_store.load_from_file(config_path) - result = databases_module.main() - - db_names = [db.name for db in result] - assert "CONFIG_DB1" in db_names - assert "CONFIG_DB2" in db_names - - def test_bootstrap_default_from_config_overridden_by_env(self, reset_config_store, make_config_file): - """databases.main() should override DEFAULT config values with env vars.""" - config_path = make_config_file( - database_configs=[ - { - "name": "DEFAULT", - "user": "config_user", - "password": "config_pass", - "dsn": "config_host:1521/CFGPDB", - }, - ], - ) - - os.environ["DB_USERNAME"] = "env_user" - os.environ["DB_PASSWORD"] = "env_password" - - try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.user == "env_user" - assert default_db.password == "env_password" - assert default_db.dsn == "config_host:1521/CFGPDB" # DSN not in env, keep config value - finally: - del os.environ["DB_USERNAME"] - del os.environ["DB_PASSWORD"] - - def test_bootstrap_raises_on_duplicate_names(self, reset_config_store, make_config_file): - """databases.main() should raise error for duplicate database names.""" - config_path = make_config_file( - database_configs=[ - {"name": "DUP_DB", "user": "user1", "password": "pass1", "dsn": "dsn1"}, - {"name": "dup_db", "user": "user2", "password": "pass2", "dsn": "dsn2"}, - ], - ) - - reset_config_store.load_from_file(config_path) - - with pytest.raises(ValueError, match="Duplicate database name"): - databases_module.main() - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class 
TestDatabasesBootstrapWithEnvVars: - """Integration tests for database bootstrap with environment variables.""" - - def test_bootstrap_uses_env_vars_for_default(self): - """databases.main() should use env vars for DEFAULT when no config.""" - os.environ["DB_USERNAME"] = "env_user" - os.environ["DB_PASSWORD"] = "env_password" - os.environ["DB_DSN"] = "env_host:1521/ENVPDB" - - try: - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.user == "env_user" - assert default_db.password == "env_password" - assert default_db.dsn == "env_host:1521/ENVPDB" - finally: - del os.environ["DB_USERNAME"] - del os.environ["DB_PASSWORD"] - del os.environ["DB_DSN"] - - def test_bootstrap_wallet_password_sets_wallet_location(self): - """databases.main() should set wallet_location when wallet_password present.""" - os.environ["DB_WALLET_PASSWORD"] = "wallet_secret" - os.environ["TNS_ADMIN"] = "/path/to/wallet" - - try: - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.wallet_password == "wallet_secret" - assert default_db.wallet_location == "/path/to/wallet" - assert default_db.config_dir == "/path/to/wallet" - finally: - del os.environ["DB_WALLET_PASSWORD"] - del os.environ["TNS_ADMIN"] - - def test_bootstrap_tns_admin_default(self): - """databases.main() should use 'tns_admin' as default config_dir.""" - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.config_dir == "tns_admin" - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestDatabasesBootstrapPreservation: - """Integration tests for database bootstrap preserving non-DEFAULT databases.""" - - def test_bootstrap_preserves_non_default_databases(self, reset_config_store, make_config_file): - """databases.main() should not modify non-DEFAULT databases.""" - os.environ["DB_USERNAME"] = "should_not_apply" - - config_path = make_config_file( - database_configs=[ - { - "name": "CUSTOM_DB", - "user": "custom_user", - "password": "custom_pass", - "dsn": "custom:1521/CPDB", - }, - ], - ) - - try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - custom_db = get_database_by_name(result, "CUSTOM_DB") - assert custom_db.user == "custom_user" - assert custom_db.password == "custom_pass" - finally: - del os.environ["DB_USERNAME"] - - def test_bootstrap_creates_default_when_not_in_config(self, reset_config_store, make_config_file): - """databases.main() should create DEFAULT from env when not in config.""" - os.environ["DB_USERNAME"] = "env_default_user" - - config_path = make_config_file( - database_configs=[ - {"name": "OTHER_DB", "user": "other", "password": "other", "dsn": "other"}, - ], - ) - - try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - assert_has_default_database(result) - assert "OTHER_DB" in [d.name for d in result] - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.user == "env_default_user" - finally: - del os.environ["DB_USERNAME"] diff --git a/test/integration/server/bootstrap/test_bootstrap_models.py b/test/integration/server/bootstrap/test_bootstrap_models.py deleted file mode 100644 index b39042ec..00000000 --- a/test/integration/server/bootstrap/test_bootstrap_models.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
- -Integration tests for server/bootstrap/models.py - -Tests the models bootstrap process with real configuration files -and environment variables. -""" - -# pylint: disable=redefined-outer-name - -import os -from unittest.mock import patch - -from test.shared_fixtures import assert_model_list_valid, get_model_by_id - -import pytest - -from server.bootstrap import models as models_module - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestModelsBootstrapBasic: - """Integration tests for basic models bootstrap functionality.""" - - def test_bootstrap_returns_model_objects(self): - """models.main() should return list of Model objects.""" - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - assert_model_list_valid(result) - - def test_bootstrap_includes_base_models(self): - """models.main() should include base model configurations.""" - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - - model_ids = [m.id for m in result] - # Check for some expected base models - assert "gpt-4o-mini" in model_ids - assert "command-r" in model_ids - - def test_bootstrap_includes_ll_and_embed_models(self): - """models.main() should include both LLM and embedding models.""" - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - - model_types = {m.type for m in result} - assert "ll" in model_types - assert "embed" in model_types - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestModelsBootstrapWithApiKeys: - """Integration tests for models bootstrap with API keys.""" - - def test_bootstrap_enables_models_with_openai_key(self): - """models.main() should enable OpenAI models when key is present.""" - os.environ["OPENAI_API_KEY"] = "test-openai-key" - - try: - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - openai_model = get_model_by_id(result, "gpt-4o-mini") - assert openai_model.enabled is True - assert openai_model.api_key == "test-openai-key" - finally: - del os.environ["OPENAI_API_KEY"] - - def test_bootstrap_enables_models_with_cohere_key(self): - """models.main() should enable Cohere models when key is present.""" - os.environ["COHERE_API_KEY"] = "test-cohere-key" - - try: - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - cohere_model = get_model_by_id(result, "command-r") - assert cohere_model.enabled is True - assert cohere_model.api_key == "test-cohere-key" - finally: - del os.environ["COHERE_API_KEY"] - - def test_bootstrap_disables_models_without_keys(self): - """models.main() should disable models when API keys are not present.""" - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - openai_model = get_model_by_id(result, "gpt-4o-mini") - assert openai_model.enabled is False # Without OPENAI_API_KEY - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestModelsBootstrapWithOnPremUrls: - """Integration tests for models bootstrap with on-prem URLs.""" - - def 
test_bootstrap_enables_ollama_with_url(self): - """models.main() should enable Ollama models when URL is set.""" - os.environ["ON_PREM_OLLAMA_URL"] = "http://localhost:11434" - - try: - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - ollama_model = get_model_by_id(result, "llama3.1") - assert ollama_model.enabled is True - assert ollama_model.api_base == "http://localhost:11434" - finally: - del os.environ["ON_PREM_OLLAMA_URL"] - - def test_bootstrap_checks_url_accessibility(self): - """models.main() should check URL accessibility for enabled models.""" - os.environ["ON_PREM_OLLAMA_URL"] = "http://localhost:11434" - - try: - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (False, "Connection refused") - result = models_module.main() - ollama_model = get_model_by_id(result, "llama3.1") - assert ollama_model.enabled is False # Should be disabled if URL not accessible - finally: - del os.environ["ON_PREM_OLLAMA_URL"] - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestModelsBootstrapWithConfigStore: - """Integration tests for models bootstrap with ConfigStore configuration.""" - - def test_bootstrap_merges_config_store_models(self, reset_config_store, make_config_file): - """models.main() should merge models from ConfigStore.""" - config_path = make_config_file( - model_configs=[ - { - "id": "custom-model", - "type": "ll", - "provider": "custom", - "enabled": True, - "api_base": "https://custom.api/v1", - "api_key": "custom-key", - }, - ], - ) - - try: - reset_config_store.load_from_file(config_path) - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - - model_ids = [m.id for m in result] - assert "custom-model" in model_ids - - custom_model = get_model_by_id(result, "custom-model") - assert custom_model.provider == "custom" - assert custom_model.api_base == "https://custom.api/v1" - finally: - pass - - def test_bootstrap_config_store_overrides_base_model(self, reset_config_store, make_config_file): - """models.main() should let ConfigStore override base model settings.""" - config_path = make_config_file( - model_configs=[ - { - "id": "gpt-4o-mini", - "type": "ll", - "provider": "openai", - "enabled": True, - "api_base": "https://api.openai.com/v1", - "api_key": "override-key", - "max_tokens": 9999, - }, - ], - ) - - try: - reset_config_store.load_from_file(config_path) - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - openai_model = get_model_by_id(result, "gpt-4o-mini") - assert openai_model.api_key == "override-key" - assert openai_model.max_tokens == 9999 - finally: - pass - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestModelsBootstrapDuplicateDetection: - """Integration tests for models bootstrap duplicate detection.""" - - def test_bootstrap_deduplicates_config_store_models(self, reset_config_store, make_config_file): - """models.main() should deduplicate models with same provider+id in ConfigStore. - - Note: ConfigStore models with the same (provider, id) key are deduplicated - during the merge process (dict keyed by tuple keeps last value). - This is different from base model duplicate detection which raises an error. 
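-
-        A minimal sketch of that merge behaviour (illustrative variable
-        names, not taken from the source):
-
-            merged = {}
-            for cfg in config_store_models:
-                merged[(cfg["provider"], cfg["id"])] = cfg  # last value wins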
- """ - # Create config with duplicate model (same provider + id) - config_path = make_config_file( - model_configs=[ - { - "id": "duplicate-model", - "type": "ll", - "provider": "test", - "api_base": "http://test1", - }, - { - "id": "duplicate-model", - "type": "ll", - "provider": "test", - "api_base": "http://test2", - }, - ], - ) - - reset_config_store.load_from_file(config_path) - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - - # Should have only one model with the duplicate id (last one wins) - dup_models = [m for m in result if m.id == "duplicate-model"] - assert len(dup_models) == 1 - # The last entry in the config should win - assert dup_models[0].api_base == "http://test2" - - def test_bootstrap_allows_same_id_different_provider(self, reset_config_store, make_config_file): - """models.main() should allow same ID with different providers.""" - config_path = make_config_file( - model_configs=[ - { - "id": "shared-model-name", - "type": "ll", - "provider": "provider1", - "api_base": "http://provider1", - }, - { - "id": "shared-model-name", - "type": "ll", - "provider": "provider2", - "api_base": "http://provider2", - }, - ], - ) - - reset_config_store.load_from_file(config_path) - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - result = models_module.main() - - # Both should be present - shared_models = [m for m in result if m.id == "shared-model-name"] - assert len(shared_models) == 2 - providers = {m.provider for m in shared_models} - assert providers == {"provider1", "provider2"} diff --git a/test/integration/server/bootstrap/test_bootstrap_oci.py b/test/integration/server/bootstrap/test_bootstrap_oci.py deleted file mode 100644 index 4d4afd47..00000000 --- a/test/integration/server/bootstrap/test_bootstrap_oci.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/bootstrap/oci.py - -Tests the OCI bootstrap process with real configuration files -and environment variables. 
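-
-Most tests below share one pattern: OCI_CLI_CONFIG_FILE is pointed at a
-nonexistent path so that oci.main() falls back to environment variables,
-e.g. (a sketch of the pattern used throughout this file):
-
-    os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config"
-    try:
-        profiles = oci_module.main()
-    finally:
-        del os.environ["OCI_CLI_CONFIG_FILE"]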
-""" - -# pylint: disable=redefined-outer-name - -import os - -import oci -import pytest - -from server.bootstrap import oci as oci_module -from common.schema import OracleCloudSettings - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestOciBootstrapWithEnvVars: - """Integration tests for OCI bootstrap with environment variables.""" - - def test_bootstrap_returns_oci_settings_objects(self): - """oci.main() should return list of OracleCloudSettings objects.""" - # Point to nonexistent OCI config to test env var path - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - - try: - result = oci_module.main() - - assert isinstance(result, list) - assert all(isinstance(s, OracleCloudSettings) for s in result) - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - def test_bootstrap_creates_default_profile(self): - """oci.main() should always create DEFAULT profile.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - - try: - result = oci_module.main() - - profile_names = [s.auth_profile for s in result] - assert oci.config.DEFAULT_PROFILE in profile_names - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - def test_bootstrap_applies_tenancy_env_var(self): - """oci.main() should apply OCI_CLI_TENANCY to DEFAULT profile.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - os.environ["OCI_CLI_TENANCY"] = "ocid1.tenancy.oc1..envtenancy" - - try: - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.tenancy == "ocid1.tenancy.oc1..envtenancy" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - del os.environ["OCI_CLI_TENANCY"] - - def test_bootstrap_applies_region_env_var(self): - """oci.main() should apply OCI_CLI_REGION to DEFAULT profile.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - os.environ["OCI_CLI_REGION"] = "us-chicago-1" - - try: - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.region == "us-chicago-1" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - del os.environ["OCI_CLI_REGION"] - - def test_bootstrap_applies_genai_env_vars(self): - """oci.main() should apply GenAI environment variables.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - os.environ["OCI_GENAI_COMPARTMENT_ID"] = "ocid1.compartment.oc1..genaicomp" - os.environ["OCI_GENAI_REGION"] = "us-chicago-1" - - try: - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.genai_compartment_id == "ocid1.compartment.oc1..genaicomp" - assert default_profile.genai_region == "us-chicago-1" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - del os.environ["OCI_GENAI_COMPARTMENT_ID"] - del os.environ["OCI_GENAI_REGION"] - - def test_bootstrap_explicit_auth_method(self): - """oci.main() should use OCI_CLI_AUTH when specified.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - os.environ["OCI_CLI_AUTH"] = "instance_principal" - - try: - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.authentication == "instance_principal" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - del os.environ["OCI_CLI_AUTH"] - - def test_bootstrap_default_auth_is_api_key(self): - """oci.main() should default to api_key authentication.""" - 
os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - - try: - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.authentication == "api_key" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestOciBootstrapWithConfigFile: - """Integration tests for OCI bootstrap with real OCI config files.""" - - def test_bootstrap_reads_oci_config_file(self, make_oci_config_file): - """oci.main() should read profiles from OCI config file.""" - config_path = make_oci_config_file( - profiles={ - "DEFAULT": { - "tenancy": "ocid1.tenancy.oc1..filetenancy", - "region": "us-ashburn-1", - "fingerprint": "file:fingerprint", - }, - } - ) - - os.environ["OCI_CLI_CONFIG_FILE"] = str(config_path) - - try: - result = oci_module.main() - - # Should have loaded the profile from file - profile_names = [s.auth_profile for s in result] - assert "DEFAULT" in profile_names - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - def test_bootstrap_loads_multiple_profiles(self, make_oci_config_file): - """oci.main() should load multiple profiles from OCI config file.""" - config_path = make_oci_config_file( - profiles={ - "DEFAULT": { - "tenancy": "ocid1.tenancy.oc1..default", - "region": "us-ashburn-1", - "fingerprint": "default:fp", - }, - "PRODUCTION": { - "tenancy": "ocid1.tenancy.oc1..production", - "region": "us-phoenix-1", - "fingerprint": "prod:fp", - }, - } - ) - - os.environ["OCI_CLI_CONFIG_FILE"] = str(config_path) - - try: - result = oci_module.main() - - profile_names = [s.auth_profile for s in result] - assert "DEFAULT" in profile_names - assert "PRODUCTION" in profile_names - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestOciBootstrapWithConfigStore: - """Integration tests for OCI bootstrap with ConfigStore configuration.""" - - def test_bootstrap_merges_config_store_profiles(self, reset_config_store, make_config_file): - """oci.main() should merge profiles from ConfigStore.""" - os.environ["OCI_CLI_CONFIG_FILE"] = "/nonexistent/oci/config" - - config_path = make_config_file( - oci_configs=[ - { - "auth_profile": "CONFIGSTORE_PROFILE", - "tenancy": "ocid1.tenancy.oc1..configstore", - "region": "us-sanjose-1", - "fingerprint": "cs:fingerprint", - }, - ], - ) - - try: - reset_config_store.load_from_file(config_path) - result = oci_module.main() - - profile_names = [s.auth_profile for s in result] - assert "CONFIGSTORE_PROFILE" in profile_names - - cs_profile = next(p for p in result if p.auth_profile == "CONFIGSTORE_PROFILE") - assert cs_profile.tenancy == "ocid1.tenancy.oc1..configstore" - assert cs_profile.region == "us-sanjose-1" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - def test_bootstrap_config_store_overrides_file_profile( - self, reset_config_store, make_config_file, make_oci_config_file - ): - """oci.main() should let ConfigStore override file profiles.""" - oci_config_path = make_oci_config_file( - profiles={ - "DEFAULT": { - "tenancy": "ocid1.tenancy.oc1..fromfile", - "region": "us-ashburn-1", - "fingerprint": "file:fp", - }, - } - ) - - config_path = make_config_file( - oci_configs=[ - { - "auth_profile": "DEFAULT", - "tenancy": "ocid1.tenancy.oc1..fromconfigstore", - "region": "us-phoenix-1", - "fingerprint": "cs:fp", - }, - ], - ) - - os.environ["OCI_CLI_CONFIG_FILE"] = str(oci_config_path) - - try: - 
reset_config_store.load_from_file(config_path) - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - # ConfigStore should override file values - assert default_profile.tenancy == "ocid1.tenancy.oc1..fromconfigstore" - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] diff --git a/test/integration/server/bootstrap/test_bootstrap_settings.py b/test/integration/server/bootstrap/test_bootstrap_settings.py deleted file mode 100644 index 1d71376c..00000000 --- a/test/integration/server/bootstrap/test_bootstrap_settings.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for server/bootstrap/settings.py - -Tests the settings bootstrap process with real configuration files. -""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import pytest - -from server.bootstrap import settings as settings_module -from common.schema import Settings - - -@pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") -class TestSettingsBootstrapWithConfig: - """Integration tests for settings bootstrap with configuration files.""" - - def test_bootstrap_creates_default_and_server_clients(self): - """settings.main() should always create default and server clients.""" - result = settings_module.main() - - assert len(result) == 2 - client_names = [s.client for s in result] - assert "default" in client_names - assert "server" in client_names - - def test_bootstrap_returns_settings_objects(self): - """settings.main() should return list of Settings objects.""" - result = settings_module.main() - - assert all(isinstance(s, Settings) for s in result) - - def test_bootstrap_with_config_file(self, reset_config_store, make_config_file): - """settings.main() should use settings from config file.""" - config_path = make_config_file( - client_settings={ - "client": "config_client", - "ll_model": { - "model": "custom-model", - "temperature": 0.9, - "max_tokens": 8192, - "chat_history": False, - }, - }, - ) - - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - # All clients should inherit config file settings - for s in result: - assert s.ll_model.model == "custom-model" - assert s.ll_model.temperature == 0.9 - assert s.ll_model.max_tokens == 8192 - assert s.ll_model.chat_history is False - - def test_bootstrap_overrides_client_names(self, reset_config_store, make_config_file): - """settings.main() should override client field to default/server.""" - config_path = make_config_file( - client_settings={ - "client": "original_client_name", - }, - ) - - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - client_names = [s.client for s in result] - assert "original_client_name" not in client_names - assert "default" in client_names - assert "server" in client_names - - def test_bootstrap_with_vector_search_settings(self, reset_config_store, make_config_file): - """settings.main() should load vector search settings from config.""" - config_path = make_config_file( - client_settings={ - "client": "vs_client", - "vector_search": { - "discovery": False, - "rephrase": False, - "grade": True, - "top_k": 10, - "search_type": "Similarity", - }, - }, - ) - - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - for s in result: - assert s.vector_search.discovery is False 
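-            # both bootstrapped clients (default and server) inherit the
-            # vector_search values loaded from the config file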
- assert s.vector_search.rephrase is False - assert s.vector_search.grade is True - assert s.vector_search.top_k == 10 - - def test_bootstrap_with_oci_settings(self, reset_config_store, make_config_file): - """settings.main() should load OCI settings from config.""" - config_path = make_config_file( - client_settings={ - "client": "oci_client", - "oci": { - "auth_profile": "CUSTOM_PROFILE", - }, - }, - ) - - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - for s in result: - assert s.oci.auth_profile == "CUSTOM_PROFILE" - - def test_bootstrap_with_database_settings(self, reset_config_store, make_config_file): - """settings.main() should load database settings from config.""" - config_path = make_config_file( - client_settings={ - "client": "db_client", - "database": { - "alias": "CUSTOM_DB", - }, - }, - ) - - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - for s in result: - assert s.database.alias == "CUSTOM_DB" - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestSettingsBootstrapWithoutConfig: - """Integration tests for settings bootstrap without configuration.""" - - def test_bootstrap_without_config_uses_defaults(self, reset_config_store): - """settings.main() should use default values without config file.""" - # Ensure no config is loaded - assert reset_config_store.get() is None - - result = settings_module.main() - - assert len(result) == 2 - # Should have default Settings values - for s in result: - assert isinstance(s, Settings) - # Default values from Settings model - assert s.oci.auth_profile == "DEFAULT" - assert s.database.alias == "DEFAULT" - - -@pytest.mark.usefixtures("clean_bootstrap_env") -class TestSettingsBootstrapIdempotency: - """Integration tests for settings bootstrap idempotency.""" - - def test_bootstrap_produces_consistent_results(self, reset_config_store): - """settings.main() should produce consistent results on multiple calls.""" - result1 = settings_module.main() - - # Reset and call again - reset_config_store._config = None - result2 = settings_module.main() - - assert len(result1) == len(result2) - for s1, s2 in zip(result1, result2): - assert s1.client == s2.client diff --git a/test/shared_fixtures.py b/test/shared_fixtures.py deleted file mode 100644 index 22953e39..00000000 --- a/test/shared_fixtures.py +++ /dev/null @@ -1,330 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Shared pytest fixtures for unit and integration tests. - -This module contains common fixture factories and utilities that are shared -across multiple test conftest files to avoid code duplication. 
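-
-Typical usage from a test (sketch; the factory fixtures are defined below):
-
-    def test_database_name(make_database):
-        db = make_database(name="MY_DB")
-        assert db.name == "MY_DB"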
-""" - -# pylint: disable=redefined-outer-name - -import json -import os -import tempfile -from pathlib import Path - -import pytest - -from common.schema import ( - Configuration, - Database, - Model, - OracleCloudSettings, - Settings, - LargeLanguageSettings, -) -from server.bootstrap.configfile import ConfigStore - - -# Default test model settings - shared across test fixtures -DEFAULT_LL_MODEL_CONFIG = { - "model": "gpt-4o-mini", - "temperature": 0.7, - "max_tokens": 4096, - "chat_history": True, -} - -# Environment variables used by bootstrap modules -BOOTSTRAP_ENV_VARS = [ - # Database vars - "DB_USERNAME", - "DB_PASSWORD", - "DB_DSN", - "DB_WALLET_PASSWORD", - "TNS_ADMIN", - # Model API keys - "OPENAI_API_KEY", - "COHERE_API_KEY", - "PPLX_API_KEY", - # On-prem model URLs - "ON_PREM_OLLAMA_URL", - "ON_PREM_VLLM_URL", - "ON_PREM_HF_URL", - # OCI vars - "OCI_CLI_CONFIG_FILE", - "OCI_CLI_TENANCY", - "OCI_CLI_REGION", - "OCI_CLI_USER", - "OCI_CLI_FINGERPRINT", - "OCI_CLI_KEY_FILE", - "OCI_CLI_SECURITY_TOKEN_FILE", - "OCI_CLI_AUTH", - "OCI_GENAI_COMPARTMENT_ID", - "OCI_GENAI_REGION", - "OCI_GENAI_SERVICE_ENDPOINT", -] - - -################################################# -# Schema Factory Fixtures -################################################# - - -@pytest.fixture -def make_database(): - """Factory fixture to create Database objects.""" - - def _make_database( - name: str = "TEST_DB", - user: str = "test_user", - password: str = "test_password", - dsn: str = "localhost:1521/TESTPDB", - wallet_password: str = None, - **kwargs, - ) -> Database: - return Database( - name=name, - user=user, - password=password, - dsn=dsn, - wallet_password=wallet_password, - **kwargs, - ) - - return _make_database - - -@pytest.fixture -def make_model(): - """Factory fixture to create Model objects. - - Supports both `model_id` and `id` parameter names for backwards compatibility. - """ - - def _make_model( - model_id: str = None, - model_type: str = "ll", - provider: str = "openai", - enabled: bool = True, - api_key: str = "test-key", - api_base: str = "https://api.openai.com/v1", - **kwargs, - ) -> Model: - # Support both 'id' kwarg and 'model_id' parameter for backwards compat - resolved_id = kwargs.pop("id", None) or model_id or "gpt-4o-mini" - return Model( - id=resolved_id, - type=model_type, - provider=provider, - enabled=enabled, - api_key=api_key, - api_base=api_base, - **kwargs, - ) - - return _make_model - - -@pytest.fixture -def make_oci_config(): - """Factory fixture to create OracleCloudSettings objects. - - Note: The 'user' field requires OCID format pattern matching. - Use None to skip the user field in tests that don't need it. 
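-
-    Tests that do need the user field should pass an OCID-shaped value,
-    e.g. user="ocid1.user.oc1..testuser" (illustrative value only).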
- """ - - def _make_oci_config( - auth_profile: str = "DEFAULT", - tenancy: str = "test-tenancy", - region: str = "us-ashburn-1", - user: str = None, # Use None by default - OCID pattern required - fingerprint: str = "test-fingerprint", - key_file: str = "/path/to/key", - **kwargs, - ) -> OracleCloudSettings: - return OracleCloudSettings( - auth_profile=auth_profile, - tenancy=tenancy, - region=region, - user=user, - fingerprint=fingerprint, - key_file=key_file, - **kwargs, - ) - - return _make_oci_config - - -@pytest.fixture -def make_ll_settings(): - """Factory fixture to create LargeLanguageSettings objects.""" - - def _make_ll_settings( - model: str = "gpt-4o-mini", - temperature: float = 0.7, - max_tokens: int = 4096, - chat_history: bool = True, - **kwargs, - ) -> LargeLanguageSettings: - return LargeLanguageSettings( - model=model, - temperature=temperature, - max_tokens=max_tokens, - chat_history=chat_history, - **kwargs, - ) - - return _make_ll_settings - - -@pytest.fixture -def make_settings(make_ll_settings): - """Factory fixture to create Settings objects.""" - - def _make_settings( - client: str = "test_client", - ll_model: LargeLanguageSettings = None, - **kwargs, - ) -> Settings: - if ll_model is None: - ll_model = make_ll_settings() - return Settings( - client=client, - ll_model=ll_model, - **kwargs, - ) - - return _make_settings - - -@pytest.fixture -def make_configuration(make_settings): - """Factory fixture to create Configuration objects.""" - - def _make_configuration( - client_settings: Settings = None, - database_configs: list = None, - model_configs: list = None, - oci_configs: list = None, - **kwargs, - ) -> Configuration: - return Configuration( - client_settings=client_settings or make_settings(), - database_configs=database_configs or [], - model_configs=model_configs or [], - oci_configs=oci_configs or [], - prompt_configs=[], - **kwargs, - ) - - return _make_configuration - - -################################################# -# Config File Fixtures -################################################# - - -@pytest.fixture -def temp_config_file(make_settings): - """Create a temporary configuration JSON file.""" - - def _create_temp_config( - client_settings: Settings = None, - database_configs: list = None, - model_configs: list = None, - oci_configs: list = None, - ): - config_data = { - "client_settings": (client_settings or make_settings()).model_dump(), - "database_configs": [ - (db if isinstance(db, dict) else db.model_dump()) - for db in (database_configs or []) - ], - "model_configs": [ - (m if isinstance(m, dict) else m.model_dump()) - for m in (model_configs or []) - ], - "oci_configs": [ - (o if isinstance(o, dict) else o.model_dump()) - for o in (oci_configs or []) - ], - "prompt_configs": [], - } - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False, encoding="utf-8" - ) as temp_file: - json.dump(config_data, temp_file) - return Path(temp_file.name) - - return _create_temp_config - - -@pytest.fixture -def reset_config_store(): - """Reset ConfigStore singleton state before and after each test.""" - # Reset before test - ConfigStore.reset() - - yield ConfigStore - - # Reset after test - ConfigStore.reset() - - -################################################# -# Test Helper Functions (shared assertions to reduce duplication) -################################################# - - -def assert_database_list_valid(result): - """Assert that result is a valid list of Database objects.""" - assert isinstance(result, list) - assert 
all(isinstance(db, Database) for db in result) - - -def assert_has_default_database(result): - """Assert that DEFAULT database is in the result.""" - db_names = [db.name for db in result] - assert "DEFAULT" in db_names - - -def get_database_by_name(result, name): - """Get a database from results by name.""" - return next(db for db in result if db.name == name) - - -def assert_model_list_valid(result): - """Assert that result is a valid list of Model objects.""" - assert isinstance(result, list) - assert all(isinstance(m, Model) for m in result) - - -def get_model_by_id(result, model_id): - """Get a model from results by id.""" - return next(m for m in result if m.id == model_id) - - -################################################# -# Environment Fixtures -################################################# - - -@pytest.fixture -def clean_env(): - """Fixture to temporarily clear relevant environment variables.""" - original_values = {} - for var in BOOTSTRAP_ENV_VARS: - original_values[var] = os.environ.pop(var, None) - - yield - - # Restore original values - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] diff --git a/test/unit/__init__.py b/test/unit/__init__.py deleted file mode 100644 index 06825972..00000000 --- a/test/unit/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit test package diff --git a/test/unit/server/__init__.py b/test/unit/server/__init__.py deleted file mode 100644 index bc4d60b5..00000000 --- a/test/unit/server/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Server unit test package diff --git a/test/unit/server/api/__init__.py b/test/unit/server/api/__init__.py deleted file mode 100644 index b4333d68..00000000 --- a/test/unit/server/api/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# API unit test package diff --git a/test/unit/server/api/conftest.py b/test/unit/server/api/conftest.py deleted file mode 100644 index 8ec4542a..00000000 --- a/test/unit/server/api/conftest.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Pytest fixtures for server/api unit tests. -Provides factory fixtures for creating test objects. 
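-
-As in test/shared_fixtures.py, these are factory fixtures: each returns a
-builder function, so one test can construct several objects (sketch):
-
-    def test_two_stores(make_vector_store):
-        primary = make_vector_store(vector_store="VS_A")
-        secondary = make_vector_store(vector_store="VS_B", chunk_size=500)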
-""" - -# pylint: disable=redefined-outer-name unused-import -# Pytest fixtures use parameter injection where fixture names match parameters - -from unittest.mock import MagicMock, AsyncMock - -# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) -from test.shared_fixtures import ( - make_database, - make_model, - make_oci_config, - make_ll_settings, - make_settings, - make_configuration, -) - -import pytest - -from common.schema import ( - DatabaseAuth, - DatabaseVectorStorage, - ChatRequest, -) - - -@pytest.fixture -def make_database_auth(): - """Factory fixture to create DatabaseAuth objects.""" - - def _make_database_auth(**overrides) -> DatabaseAuth: - defaults = { - "user": "test_user", - "password": "test_password", - "dsn": "localhost:1521/TESTPDB", - "wallet_password": None, - } - defaults.update(overrides) - return DatabaseAuth(**defaults) - - return _make_database_auth - - -@pytest.fixture -def make_vector_store(): - """Factory fixture to create DatabaseVectorStorage objects.""" - - def _make_vector_store( - vector_store: str = "VS_TEST", - model: str = "text-embedding-3-small", - chunk_size: int = 1000, - chunk_overlap: int = 200, - **kwargs, - ) -> DatabaseVectorStorage: - return DatabaseVectorStorage( - vector_store=vector_store, - model=model, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - **kwargs, - ) - - return _make_vector_store - - -@pytest.fixture -def make_chat_request(): - """Factory fixture to create ChatRequest objects.""" - - def _make_chat_request( - content: str = "Hello", - role: str = "user", - **kwargs, - ) -> ChatRequest: - return ChatRequest( - messages=[{"role": role, "content": content}], - **kwargs, - ) - - return _make_chat_request - - -@pytest.fixture -def make_mcp_prompt(): - """Factory fixture to create MCP prompt mock objects.""" - - def _make_mcp_prompt( - name: str = "optimizer_test-prompt", - description: str = "Test prompt description", - text: str = "Test prompt text content", - ): - mock_prompt = MagicMock() - mock_prompt.name = name - mock_prompt.description = description - mock_prompt.text = text - mock_prompt.model_dump.return_value = { - "name": name, - "description": description, - "text": text, - } - return mock_prompt - - return _make_mcp_prompt - - -@pytest.fixture -def mock_fastmcp(): - """Create a mock FastMCP application.""" - mock_mcp = MagicMock() - mock_mcp.list_tools = AsyncMock(return_value=[]) - mock_mcp.list_resources = AsyncMock(return_value=[]) - mock_mcp.list_prompts = AsyncMock(return_value=[]) - return mock_mcp - - -@pytest.fixture -def mock_mcp_client(): - """Create a mock MCP client.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[]) - mock_client.get_prompt = AsyncMock(return_value=MagicMock()) - mock_client.close = AsyncMock() - return mock_client - - -@pytest.fixture -def mock_db_connection(): - """Create a mock database connection for endpoint tests. - - This mock is used by v1 endpoint tests that mock the underlying - database utilities. It provides a simple MagicMock that can be - passed around without needing a real database connection. - - For tests that need actual database operations, use the real - db_connection or db_transaction fixtures from test/conftest.py. 
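-
-    Sketch of the call shapes the mock supports (all no-op MagicMocks):
-
-        with mock_db_connection.cursor() as cur:
-            ...
-        mock_db_connection.commit()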
- """ - mock_conn = MagicMock() - mock_conn.cursor.return_value.__enter__ = MagicMock() - mock_conn.cursor.return_value.__exit__ = MagicMock() - mock_conn.commit = MagicMock() - mock_conn.rollback = MagicMock() - mock_conn.close = MagicMock() - return mock_conn - - -@pytest.fixture -def mock_request_app_state(mock_fastmcp): - """Create a mock FastAPI request with app state.""" - mock_request = MagicMock() - mock_request.app.state.fastmcp_app = mock_fastmcp - return mock_request - - -@pytest.fixture -def mock_bootstrap(): - """Create mocks for bootstrap module dependencies.""" - return { - "databases": [], - "models": [], - "oci_configs": [], - "prompts": [], - "settings": [], - } - - -def create_mock_aiohttp_session(mock_session_class, mock_response): - """Helper to create a mock aiohttp ClientSession with response. - - This is a shared utility for tests that need to mock aiohttp.ClientSession. - It properly sets up async context manager behavior for session.get(). - - Args: - mock_session_class: The patched aiohttp.ClientSession class - mock_response: The mock response object to return from session.get() - - Returns: - The configured mock session object - """ - mock_session = AsyncMock() - mock_session.get = MagicMock( - return_value=AsyncMock(__aenter__=AsyncMock(return_value=mock_response)) - ) - mock_session.__aenter__ = AsyncMock(return_value=mock_session) - mock_session.__aexit__ = AsyncMock() - mock_session_class.return_value = mock_session - return mock_session diff --git a/test/unit/server/api/utils/__init__.py b/test/unit/server/api/utils/__init__.py deleted file mode 100644 index 9d9b7b29..00000000 --- a/test/unit/server/api/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Utils unit test package diff --git a/test/unit/server/api/utils/test_utils_chat.py b/test/unit/server/api/utils/test_utils_chat.py deleted file mode 100644 index 22f952a6..00000000 --- a/test/unit/server/api/utils/test_utils_chat.py +++ /dev/null @@ -1,312 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/chat.py -Tests for chat completion utility functions. 
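-
-completion_generator is exercised as an async generator in two modes:
-"completions" yields one final response dict, while "streams" yields byte
-chunks followed by a "[stream_finished]" sentinel (sketch):
-
-    async for output in utils_chat.completion_generator(client, request, "streams"):
-        ...  # b"Hello", b" World", then "[stream_finished]"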
-""" - -from unittest.mock import patch, MagicMock -import pytest - -from server.api.utils import chat as utils_chat -from server.api.utils.models import UnknownModelError -from common.schema import ChatRequest - - -class TestCompletionGenerator: - """Tests for the completion_generator function.""" - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_completions_mode( - self, - mock_graph, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should yield final response in completions mode.""" - mock_get_client.return_value = make_settings() - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "gpt-4o-mini"} - - async def mock_astream(**_kwargs): - yield {"completion": {"choices": [{"message": {"content": "Hello!"}}]}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Hi") - results = [] - async for output in utils_chat.completion_generator("test_client", request, "completions"): - results.append(output) - - assert len(results) == 1 - assert results[0]["choices"][0]["message"]["content"] == "Hello!" - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_streams_mode( - self, - mock_graph, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should yield stream chunks in streams mode.""" - mock_get_client.return_value = make_settings() - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "gpt-4o-mini"} - - async def mock_astream(**_kwargs): - yield {"stream": "Hello"} - yield {"stream": " World"} - yield {"completion": {"choices": []}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Hi") - results = [] - async for output in utils_chat.completion_generator("test_client", request, "streams"): - results.append(output) - - # Should have 3 outputs: 2 stream chunks + stream_finished - assert len(results) == 3 - assert results[0] == b"Hello" - assert results[1] == b" World" - assert results[2] == "[stream_finished]" - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.completion") - async def test_completion_generator_unknown_model_error( - self, - mock_completion, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should return error response on UnknownModelError.""" - mock_get_client.return_value = make_settings() - mock_oci_get.return_value = make_oci_config() - mock_get_config.side_effect = UnknownModelError("Model not found") - - mock_error_response = MagicMock() - mock_error_response.choices = [MagicMock()] - mock_error_response.choices[0].message.content = "I'm unable to initialise the Language Model." 
- mock_completion.return_value = mock_error_response - - request = make_chat_request(content="Hi") - results = [] - async for output in utils_chat.completion_generator("test_client", request, "completions"): - results.append(output) - - assert len(results) == 1 - mock_completion.assert_called_once() - # Verify mock_response was used - call_kwargs = mock_completion.call_args.kwargs - assert "mock_response" in call_kwargs - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_uses_request_model( - self, mock_graph, mock_get_config, mock_oci_get, mock_get_client, make_settings, make_oci_config - ): - """completion_generator should use model from request if provided.""" - mock_get_client.return_value = make_settings() - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "claude-3"} - - async def mock_astream(**_kwargs): - yield {"completion": {}} - - mock_graph.astream = mock_astream - - request = ChatRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3") - async for _ in utils_chat.completion_generator("test_client", request, "completions"): - pass - - # get_litellm_config should be called with the request model - call_args = mock_get_config.call_args[0] - assert call_args[0]["model"] == "claude-3" - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_uses_settings_model_when_not_in_request( - self, - mock_graph, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - make_ll_settings, - ): - """completion_generator should use model from settings when not in request.""" - settings = make_settings(ll_model=make_ll_settings(model="gpt-4-turbo")) - mock_get_client.return_value = settings - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "gpt-4-turbo"} - - async def mock_astream(**_kwargs): - yield {"completion": {}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Hi") # No model specified - async for _ in utils_chat.completion_generator("test_client", request, "completions"): - pass - - # get_litellm_config should be called with settings model - call_args = mock_get_config.call_args[0] - assert call_args[0]["model"] == "gpt-4-turbo" - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.utils_databases.get_client_database") - @patch("server.api.utils.chat.utils_models.get_client_embed") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_with_vector_search_enabled( - self, - mock_graph, - mock_get_embed, - mock_get_db, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should setup db connection when vector search enabled.""" - settings = make_settings() - settings.tools_enabled = ["Vector Search"] - mock_get_client.return_value = settings - mock_oci_get.return_value 
= make_oci_config() - mock_get_config.return_value = {"model": "gpt-4o-mini"} - - mock_db = MagicMock() - mock_db.connection = MagicMock() - mock_get_db.return_value = mock_db - mock_get_embed.return_value = MagicMock() - - async def mock_astream(**_kwargs): - yield {"completion": {}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Hi") - async for _ in utils_chat.completion_generator("test_client", request, "completions"): - pass - - mock_get_db.assert_called_once_with("test_client", False) - mock_get_embed.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_passes_correct_config( - self, - mock_graph, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should pass correct config to chatbot_graph.""" - settings = make_settings() - mock_get_client.return_value = settings - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "gpt-4o-mini"} - - captured_kwargs = {} - - async def mock_astream(**kwargs): - captured_kwargs.update(kwargs) - yield {"completion": {}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Test message") - async for _ in utils_chat.completion_generator("test_client", request, "completions"): - pass - - assert captured_kwargs["stream_mode"] == "custom" - assert captured_kwargs["config"]["configurable"]["thread_id"] == "test_client" - assert captured_kwargs["config"]["metadata"]["streaming"] is False - - @pytest.mark.asyncio - @patch("server.api.utils.chat.utils_settings.get_client") - @patch("server.api.utils.chat.utils_oci.get") - @patch("server.api.utils.chat.utils_models.get_litellm_config") - @patch("server.api.utils.chat.chatbot_graph") - async def test_completion_generator_streaming_metadata( - self, - mock_graph, - mock_get_config, - mock_oci_get, - mock_get_client, - make_settings, - make_chat_request, - make_oci_config, - ): - """completion_generator should set streaming=True for streams mode.""" - mock_get_client.return_value = make_settings() - mock_oci_get.return_value = make_oci_config() - mock_get_config.return_value = {"model": "gpt-4o-mini"} - - captured_kwargs = {} - - async def mock_astream(**kwargs): - captured_kwargs.update(kwargs) - yield {"completion": {}} - - mock_graph.astream = mock_astream - - request = make_chat_request(content="Test") - async for _ in utils_chat.completion_generator("test_client", request, "streams"): - pass - - assert captured_kwargs["config"]["metadata"]["streaming"] is True - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_chat, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_chat.logger.name == "api.utils.chat" diff --git a/test/unit/server/api/utils/test_utils_databases.py b/test/unit/server/api/utils/test_utils_databases.py deleted file mode 100644 index ebc92d80..00000000 --- a/test/unit/server/api/utils/test_utils_databases.py +++ /dev/null @@ -1,657 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
- -Unit tests for server/api/utils/databases.py -Tests for database utility functions. - -Uses hybrid approach: -- Real Oracle database for connection/SQL execution tests -- Mocks for pure Python logic tests (in-memory operations, exception handling) -""" - -# pylint: disable=too-few-public-methods - -from test.conftest import TEST_CONFIG -from unittest.mock import patch, MagicMock - -import pytest -import oracledb - -from common.schema import DatabaseSettings -from server.api.utils import databases as utils_databases -from server.api.utils.databases import DbException, ExistsDatabaseError, UnknownDatabaseError - - -class TestDbException: - """Tests for DbException class.""" - - def test_db_exception_init(self): - """DbException should store status_code and detail.""" - exc = DbException(status_code=404, detail="Not found") - assert exc.status_code == 404 - assert exc.detail == "Not found" - - def test_db_exception_message(self): - """DbException should use detail as message.""" - exc = DbException(status_code=500, detail="Server error") - assert str(exc) == "Server error" - - -class TestExistsDatabaseError: - """Tests for ExistsDatabaseError class.""" - - def test_exists_database_error_is_value_error(self): - """ExistsDatabaseError should inherit from ValueError.""" - exc = ExistsDatabaseError("Database exists") - assert isinstance(exc, ValueError) - - -class TestUnknownDatabaseError: - """Tests for UnknownDatabaseError class.""" - - def test_unknown_database_error_is_value_error(self): - """UnknownDatabaseError should inherit from ValueError.""" - exc = UnknownDatabaseError("Database not found") - assert isinstance(exc, ValueError) - - -class TestCreate: - """Tests for the create function.""" - - @patch("server.api.utils.databases.get") - @patch("server.api.utils.databases.DATABASE_OBJECTS", []) - def test_create_success(self, mock_get, make_database): - """create should add database to DATABASE_OBJECTS.""" - mock_get.side_effect = [UnknownDatabaseError("Not found"), [make_database()]] - database = make_database(name="NEW_DB") - - result = utils_databases.create(database) - - assert result is not None - - @patch("server.api.utils.databases.get") - def test_create_raises_exists_error(self, mock_get, make_database): - """create should raise ExistsDatabaseError if database exists.""" - mock_get.return_value = [make_database(name="EXISTING_DB")] - database = make_database(name="EXISTING_DB") - - with pytest.raises(ExistsDatabaseError): - utils_databases.create(database) - - @patch("server.api.utils.databases.get") - def test_create_raises_value_error_missing_fields(self, mock_get, make_database): - """create should raise ValueError if required fields missing.""" - mock_get.side_effect = UnknownDatabaseError("Not found") - database = make_database(user=None) - - with pytest.raises(ValueError) as exc_info: - utils_databases.create(database) - - assert "user" in str(exc_info.value) - - -class TestGet: - """Tests for the get function.""" - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_all_databases(self, mock_objects, make_database): - """get should return all databases when no name provided.""" - mock_objects.__iter__ = lambda _: iter([make_database(name="DB1"), make_database(name="DB2")]) - mock_objects.__len__ = lambda _: 2 - - result = utils_databases.get() - - assert len(result) == 2 - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_specific_database(self, mock_objects, make_database): - """get should return specific database when name 
provided.""" - db1 = make_database(name="DB1") - db2 = make_database(name="DB2") - mock_objects.__iter__ = lambda _: iter([db1, db2]) - mock_objects.__len__ = lambda _: 2 - - result = utils_databases.get(name="DB1") - - assert len(result) == 1 - assert result[0].name == "DB1" - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_raises_unknown_error(self, mock_objects): - """get should raise UnknownDatabaseError if name not found.""" - mock_objects.__iter__ = lambda _: iter([]) - mock_objects.__len__ = lambda _: 0 - - with pytest.raises(UnknownDatabaseError): - utils_databases.get(name="NONEXISTENT") - - -class TestDelete: - """Tests for the delete function.""" - - def test_delete_removes_database(self, make_database): - """delete should remove database from DATABASE_OBJECTS.""" - db1 = make_database(name="DB1") - db2 = make_database(name="DB2") - - with patch("server.api.utils.databases.DATABASE_OBJECTS", [db1, db2]) as mock_objects: - utils_databases.delete("DB1") - assert len(mock_objects) == 1 - assert mock_objects[0].name == "DB2" - - -class TestConnect: - """Tests for the connect function. - - Uses real database for success case, mocks for error code testing - (since we can't easily trigger specific Oracle errors). - """ - - def test_connect_success_real_db(self, db_container, make_database): - """connect should return connection on success (real database).""" - # pylint: disable=unused-argument - config = make_database( - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - - result = utils_databases.connect(config) - - assert result is not None - assert result.is_healthy() - result.close() - - def test_connect_raises_value_error_missing_details(self, make_database): - """connect should raise ValueError if connection details missing.""" - config = make_database(user=None, password=None, dsn=None) - - with pytest.raises(ValueError) as exc_info: - utils_databases.connect(config) - - assert "missing connection details" in str(exc_info.value) - - def test_connect_raises_permission_error_invalid_credentials(self, db_container, make_database): - """connect should raise PermissionError on invalid credentials (real database).""" - # pylint: disable=unused-argument - config = make_database( - user="INVALID_USER", - password="wrong_password", - dsn=TEST_CONFIG["db_dsn"], - ) - - with pytest.raises(PermissionError): - utils_databases.connect(config) - - def test_connect_raises_connection_error_invalid_dsn(self, db_container, make_database): - """connect should raise ConnectionError on invalid service name (real database). - - Note: DPY-6005 (cannot connect) wraps DPY-6001 (service not registered), - and the current implementation maps DPY-6005 to ConnectionError. 
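-
-        Sketch of the error mapping exercised in this class (from the
-        tests below; not exhaustive):
-
-            DPY-6005 (cannot connect)     -> ConnectionError
-            ORA-28009 (SYS logon)         -> PermissionError
-            unmapped ORA codes            -> DatabaseError re-raised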
- """ - # pylint: disable=unused-argument - config = make_database( - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn="//localhost:1525/NONEXISTENT_SERVICE", - ) - - with pytest.raises(ConnectionError): - utils_databases.connect(config) - - @patch("server.api.utils.databases.oracledb.connect") - def test_connect_raises_connection_error_on_oserror(self, mock_connect, make_database): - """connect should raise ConnectionError on OSError (mocked - can't easily trigger).""" - mock_connect.side_effect = OSError("Network unreachable") - config = make_database() - - with pytest.raises(ConnectionError): - utils_databases.connect(config) - - @patch("server.api.utils.databases.oracledb.connect") - def test_connect_wallet_location_defaults_to_config_dir(self, mock_connect, make_database): - """connect should default wallet_location to config_dir if not set (mocked - verifies call args).""" - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - config = make_database(wallet_password="secret", config_dir="/path/to/config") - - utils_databases.connect(config) - - call_kwargs = mock_connect.call_args.kwargs - assert call_kwargs.get("wallet_location") == "/path/to/config" - - @patch("server.api.utils.databases.oracledb.connect") - def test_connect_raises_permission_error_on_ora_28009(self, mock_connect, make_database): - """connect should raise PermissionError with custom message on ORA-28009 (mocked).""" - # Create a mock error object with full_code and message - mock_error = MagicMock() - mock_error.full_code = "ORA-28009" - mock_error.message = "connection not allowed" - mock_connect.side_effect = oracledb.DatabaseError(mock_error) - config = make_database(user="SYS") - - with pytest.raises(PermissionError) as exc_info: - utils_databases.connect(config) - - assert "Connecting as SYS is not permitted" in str(exc_info.value) - - @patch("server.api.utils.databases.oracledb.connect") - def test_connect_reraises_unmapped_database_error(self, mock_connect, make_database): - """connect should re-raise unmapped DatabaseError codes (mocked).""" - # Create a mock error object with an unmapped error code - mock_error = MagicMock() - mock_error.full_code = "ORA-12345" - mock_error.message = "some other error" - mock_connect.side_effect = oracledb.DatabaseError(mock_error) - config = make_database() - - with pytest.raises(oracledb.DatabaseError): - utils_databases.connect(config) - - -class TestDisconnect: - """Tests for the disconnect function.""" - - def test_disconnect_closes_connection(self): - """disconnect should call close on connection.""" - mock_conn = MagicMock() - - utils_databases.disconnect(mock_conn) - - mock_conn.close.assert_called_once() - - -class TestExecuteSql: - """Tests for the execute_sql function. - - Uses real database for actual SQL execution tests. 
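-
-    Minimal call shapes exercised below:
-
-        rows = utils_databases.execute_sql(conn, "SELECT 'x' FROM dual")
-        rows = utils_databases.execute_sql(conn, "SELECT :v FROM dual", {"v": "x"})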
- """ - - def test_execute_sql_returns_rows(self, db_transaction): - """execute_sql should return query results (real database).""" - result = utils_databases.execute_sql(db_transaction, "SELECT 'val1' AS col1, 'val2' AS col2 FROM dual") - - assert len(result) == 1 - assert result[0] == ("val1", "val2") - - def test_execute_sql_with_binds(self, db_transaction): - """execute_sql should pass binds to cursor (real database).""" - result = utils_databases.execute_sql( - db_transaction, "SELECT :val AS result FROM dual", {"val": "test_value"} - ) - - assert result[0] == ("test_value",) - - def test_execute_sql_handles_clob_columns(self, db_transaction): - """execute_sql should read CLOB column values (real database).""" - # Create a CLOB using TO_CLOB function - result = utils_databases.execute_sql( - db_transaction, "SELECT TO_CLOB('CLOB content here') AS clob_col FROM dual" - ) - - # Result should have the CLOB content read as string - assert len(result) == 1 - assert "CLOB content here" in str(result[0]) - - def test_execute_sql_returns_dbms_output(self, db_transaction): - """execute_sql should return DBMS_OUTPUT when no rows (real database).""" - result = utils_databases.execute_sql( - db_transaction, - """ - BEGIN - DBMS_OUTPUT.ENABLE; - DBMS_OUTPUT.PUT_LINE('Test DBMS Output'); - END; - """, - ) - - assert "Test DBMS Output" in str(result) - - def test_execute_sql_multiple_rows(self, db_transaction): - """execute_sql should handle multiple rows (real database).""" - result = utils_databases.execute_sql( - db_transaction, - """ - SELECT LEVEL AS num FROM dual CONNECT BY LEVEL <= 3 - """, - ) - - assert len(result) == 3 - assert result[0] == (1,) - assert result[1] == (2,) - assert result[2] == (3,) - - def test_execute_sql_logs_table_exists_error(self, db_connection, caplog): - """execute_sql should log ORA-00955 table exists error (real database). - - Note: Due to a bug in the source code (two if statements instead of elif), - the function logs 'Table exists' but still raises. This test verifies - the logging behavior and that the error is raised. - """ - cursor = db_connection.cursor() - table_name = "TEST_DUPLICATE_TABLE" - - try: - # Create table first - cursor.execute(f"CREATE TABLE {table_name} (id NUMBER)") - db_connection.commit() - - # Try to create it again - logs 'Table exists' but raises due to bug - with pytest.raises(oracledb.DatabaseError): - utils_databases.execute_sql( - db_connection, - f"CREATE TABLE {table_name} (id NUMBER)", - ) - - # Verify the logging happened - assert "Table exists" in caplog.text - - finally: - try: - cursor.execute(f"DROP TABLE {table_name} PURGE") - db_connection.commit() - except oracledb.DatabaseError: - pass - cursor.close() - - def test_execute_sql_handles_table_not_exists_error(self, db_connection, caplog): - """execute_sql should handle ORA-00942 table not exists error (real database). - - The function logs 'Table does not exist' and returns None (doesn't raise) - for error code 942. 
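-
-        (Contrast with ORA-00955 above, which is logged but still raised.)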
- """ - # Try to select from a non-existent table - result = utils_databases.execute_sql( - db_connection, - "SELECT * FROM NONEXISTENT_TABLE_12345", - ) - - # Should not raise, returns None - assert result is None - - # Verify the logging happened - assert "Table does not exist" in caplog.text - - def test_execute_sql_raises_on_other_database_error(self, db_transaction): - """execute_sql should raise on other DatabaseError codes (real database).""" - # Invalid SQL syntax should raise - with pytest.raises(oracledb.DatabaseError): - utils_databases.execute_sql(db_transaction, "INVALID SQL SYNTAX HERE") - - def test_execute_sql_raises_on_interface_error(self): - """execute_sql should raise on InterfaceError (mocked).""" - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) - mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) - mock_cursor.callproc.side_effect = oracledb.InterfaceError("Interface error") - - with pytest.raises(oracledb.InterfaceError): - utils_databases.execute_sql(mock_conn, "SELECT 1 FROM dual") - - def test_execute_sql_raises_on_database_error_no_args(self): - """execute_sql should raise on DatabaseError with no args (mocked).""" - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) - mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) - # DatabaseError with empty args - mock_cursor.callproc.side_effect = oracledb.DatabaseError() - - with pytest.raises(oracledb.DatabaseError): - utils_databases.execute_sql(mock_conn, "SELECT 1 FROM dual") - - -class TestDropVs: - """Tests for the drop_vs function.""" - - @patch("server.api.utils.databases.LangchainVS.drop_table_purge") - def test_drop_vs_calls_langchain(self, mock_drop): - """drop_vs should call LangchainVS.drop_table_purge.""" - mock_conn = MagicMock() - - utils_databases.drop_vs(mock_conn, "VS_TEST") - - mock_drop.assert_called_once_with(mock_conn, "VS_TEST") - - -class TestGetDatabases: - """Tests for the get_databases function.""" - - @patch("server.api.utils.databases.get") - def test_get_databases_without_name(self, mock_get, make_database): - """get_databases should return all databases without name.""" - mock_get.return_value = [make_database(name="DB1"), make_database(name="DB2")] - - result = utils_databases.get_databases() - - assert len(result) == 2 - - @patch("server.api.utils.databases.get") - def test_get_databases_with_name(self, mock_get, make_database): - """get_databases should return single database with name.""" - mock_get.return_value = [make_database(name="DB1")] - - result = utils_databases.get_databases(db_name="DB1") - - assert result.name == "DB1" - - @patch("server.api.utils.databases.get") - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases._get_vs") - def test_get_databases_with_validate(self, mock_get_vs, mock_connect, mock_get, make_database): - """get_databases should validate connections when validate=True.""" - db = make_database(name="DB1") - mock_get.return_value = [db] - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_get_vs.return_value = [] - - result = utils_databases.get_databases(validate=True) - - mock_connect.assert_called_once() - assert result[0].connected is True - - @patch("server.api.utils.databases.get") - @patch("server.api.utils.databases.connect") - def test_get_databases_validate_handles_connection_error(self, 
mock_connect, mock_get, make_database): - """get_databases should continue on connection error during validation.""" - db = make_database(name="DB1") - mock_get.return_value = [db] - mock_connect.side_effect = ConnectionError("Cannot connect") - - result = utils_databases.get_databases(validate=True) - - assert len(result) == 1 - # Should not crash, just continue - - -class TestGetClientDatabase: - """Tests for the get_client_database function.""" - - @patch("server.api.utils.databases.utils_settings.get_client") - @patch("server.api.utils.databases.get_databases") - def test_get_client_database_default(self, mock_get_databases, mock_get_client, make_settings, make_database): - """get_client_database should default to DEFAULT database.""" - mock_get_client.return_value = make_settings() - mock_get_databases.return_value = make_database(name="DEFAULT") - - utils_databases.get_client_database("test_client") - - mock_get_databases.assert_called_once_with(db_name="DEFAULT", validate=False) - - @patch("server.api.utils.databases.utils_settings.get_client") - @patch("server.api.utils.databases.get_databases") - def test_get_client_database_from_database_settings( - self, mock_get_databases, mock_get_client, make_settings, make_database - ): - """get_client_database should use database alias from Settings.database.""" - settings = make_settings() - settings.database = DatabaseSettings(alias="CUSTOM_DB") - mock_get_client.return_value = settings - mock_get_databases.return_value = make_database(name="CUSTOM_DB") - - utils_databases.get_client_database("test_client") - - # Should use the alias from Settings.database - mock_get_databases.assert_called_once_with(db_name="CUSTOM_DB", validate=False) - - @patch("server.api.utils.databases.utils_settings.get_client") - @patch("server.api.utils.databases.get_databases") - def test_get_client_database_with_validate( - self, mock_get_databases, mock_get_client, make_settings, make_database - ): - """get_client_database should pass validate flag.""" - mock_get_client.return_value = make_settings() - mock_get_databases.return_value = make_database() - - utils_databases.get_client_database("test_client", validate=True) - - mock_get_databases.assert_called_once_with(db_name="DEFAULT", validate=True) - - -class TestTestConnection: # pylint: disable=protected-access - """Tests for the _test function.""" - - def test_test_connection_active(self, make_database): - """_test should set connected=True when ping succeeds.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.return_value = None - config.set_connection(mock_conn) - - utils_databases._test(config) - - assert config.connected is True - - @patch("server.api.utils.databases.connect") - def test_test_connection_refreshes_on_database_error(self, mock_connect, make_database): - """_test should refresh connection on DatabaseError.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.side_effect = oracledb.DatabaseError("Connection lost") - config.set_connection(mock_conn) - mock_connect.return_value = MagicMock() - - utils_databases._test(config) - - mock_connect.assert_called_once_with(config) - - def test_test_raises_db_exception_on_value_error(self, make_database): - """_test should raise DbException on ValueError.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.side_effect = ValueError("Invalid config") - config.set_connection(mock_conn) - - with pytest.raises(DbException) as exc_info: - utils_databases._test(config) - - assert 
exc_info.value.status_code == 400 - - def test_test_raises_db_exception_on_permission_error(self, make_database): - """_test should raise DbException on PermissionError.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.side_effect = PermissionError("Access denied") - config.set_connection(mock_conn) - - with pytest.raises(DbException) as exc_info: - utils_databases._test(config) - - assert exc_info.value.status_code == 401 - - def test_test_raises_db_exception_on_connection_error(self, make_database): - """_test should raise DbException on ConnectionError.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.side_effect = ConnectionError("Network error") - config.set_connection(mock_conn) - - with pytest.raises(DbException) as exc_info: - utils_databases._test(config) - - assert exc_info.value.status_code == 503 - - def test_test_raises_db_exception_on_generic_exception(self, make_database): - """_test should raise DbException with 500 on generic Exception.""" - config = make_database() - mock_conn = MagicMock() - mock_conn.ping.side_effect = RuntimeError("Unexpected error") - config.set_connection(mock_conn) - - with pytest.raises(DbException) as exc_info: - utils_databases._test(config) - - assert exc_info.value.status_code == 500 - assert "Unexpected error" in exc_info.value.detail - - -class TestGetVs: # pylint: disable=protected-access - """Tests for the _get_vs function. - - Uses real database - queries user_tables for vector store metadata. - Note: Results depend on actual tables in test database schema. - """ - - def test_get_vs_returns_list(self, db_transaction): - """_get_vs should return a list (real database).""" - result = utils_databases._get_vs(db_transaction) - - # Should return a list (may be empty if no vector stores exist) - assert isinstance(result, list) - - def test_get_vs_empty_for_clean_schema(self, db_transaction): - """_get_vs should return empty list when no vector stores (real database).""" - # In a clean test schema, there should be no vector stores - result = utils_databases._get_vs(db_transaction) - - # Either empty or returns actual vector stores if they exist - assert isinstance(result, list) - - def test_get_vs_parses_genai_comment(self, db_connection): - """_get_vs should parse GENAI comment JSON and return DatabaseVectorStorage (real database).""" - cursor = db_connection.cursor() - table_name = "VS_TEST_TABLE" - - try: - # Create a test table - cursor.execute(f"CREATE TABLE {table_name} (id NUMBER, data VARCHAR2(100))") - - # Add GENAI comment with JSON metadata (matching the expected format) - comment_json = '{"description": "Test vector store"}' - cursor.execute(f"COMMENT ON TABLE {table_name} IS 'GENAI: {comment_json}'") - db_connection.commit() - - # Test _get_vs - result = utils_databases._get_vs(db_connection) - - # Should find our test table - vs_names = [vs.vector_store for vs in result] - assert table_name in vs_names - - # Find our test vector store and verify parsed data - test_vs = next(vs for vs in result if vs.vector_store == table_name) - assert test_vs.description == "Test vector store" - - finally: - # Cleanup - drop table - try: - cursor.execute(f"DROP TABLE {table_name} PURGE") - db_connection.commit() - except oracledb.DatabaseError: - pass - cursor.close() - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_databases, "logger") - - def test_logger_name(self): - """Logger 
should have correct name.""" - assert utils_databases.logger.name == "api.utils.database" diff --git a/test/unit/server/api/utils/test_utils_embed.py b/test/unit/server/api/utils/test_utils_embed.py deleted file mode 100644 index dbaf2c4d..00000000 --- a/test/unit/server/api/utils/test_utils_embed.py +++ /dev/null @@ -1,805 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/embed.py -Tests for document embedding and vector store utility functions. - -Uses hybrid approach: -- Real Oracle database for vector store query tests -- Mocks for file processing logic (document loaders, splitting, etc.) -""" - -# pylint: disable=too-few-public-methods - -import json -import os -from unittest.mock import patch, MagicMock -import pytest - -from langchain_core.documents import Document as LangchainDocument - -from server.api.utils import embed as utils_embed - - -class TestUpdateVsComment: - """Tests for the update_vs_comment function.""" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.functions.get_vs_table") - @patch("server.api.utils.embed.utils_databases.execute_sql") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_update_vs_comment_success( - self, mock_disconnect, mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store - ): - """update_vs_comment should execute comment SQL.""" - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_get_vs_table.return_value = ("VS_TEST", '{"alias": "test"}') - - db_details = make_database() - vector_store = make_vector_store(vector_store="VS_TEST") - - utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) - - mock_connect.assert_called_once_with(db_details) - mock_execute_sql.assert_called_once() - mock_disconnect.assert_called_once_with(mock_conn) - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.functions.get_vs_table") - @patch("server.api.utils.embed.utils_databases.execute_sql") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_update_vs_comment_builds_correct_sql( - self, _mock_disconnect, mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store - ): - """update_vs_comment should build correct COMMENT ON TABLE SQL.""" - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_get_vs_table.return_value = ("VS_MY_STORE", '{"alias": "my_alias", "model": "embed-3"}') - - db_details = make_database() - vector_store = make_vector_store(vector_store="VS_MY_STORE") - - utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) - - call_args = mock_execute_sql.call_args[0] - sql = call_args[1] - assert "COMMENT ON TABLE VS_MY_STORE IS" in sql - assert "GENAI:" in sql - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.functions.get_vs_table") - @patch("server.api.utils.embed.utils_databases.execute_sql") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_update_vs_comment_disconnects_on_success( - self, mock_disconnect, _mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store - ): - """update_vs_comment should disconnect from database after execution.""" - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_get_vs_table.return_value = 
("VS_TEST", "{}") - - db_details = make_database() - vector_store = make_vector_store() - - utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) - - mock_disconnect.assert_called_once_with(mock_conn) - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.functions.get_vs_table") - @patch("server.api.utils.embed.utils_databases.execute_sql") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_update_vs_comment_calls_get_vs_table_with_correct_params( - self, _mock_disconnect, _mock_execute_sql, mock_get_vs_table, mock_connect, make_database, make_vector_store - ): - """update_vs_comment should call get_vs_table excluding database and vector_store.""" - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_get_vs_table.return_value = ("VS_TEST", "{}") - - db_details = make_database() - vector_store = make_vector_store( - vector_store="VS_TEST", - model="embed-model", - chunk_size=500, - chunk_overlap=100, - ) - - utils_embed.update_vs_comment(vector_store=vector_store, db_details=db_details) - - mock_get_vs_table.assert_called_once() - call_kwargs = mock_get_vs_table.call_args.kwargs - # Should NOT include database or vector_store - assert "database" not in call_kwargs - assert "vector_store" not in call_kwargs - # Should include other fields - assert "model" in call_kwargs or "chunk_size" in call_kwargs - - -class TestGetTempDirectory: - """Tests for the get_temp_directory function.""" - - @patch("server.api.utils.embed.Path") - def test_get_temp_directory_uses_app_tmp(self, mock_path): - """Should use /app/tmp if it exists.""" - mock_app_path = MagicMock() - mock_app_path.exists.return_value = True - mock_app_path.is_dir.return_value = True - mock_path.return_value = mock_app_path - mock_path.side_effect = lambda x: mock_app_path if x == "/app/tmp" else MagicMock() - - result = utils_embed.get_temp_directory("test_client", "embed") - - assert result is not None - - @patch("server.api.utils.embed.Path") - def test_get_temp_directory_uses_tmp_fallback(self, mock_path): - """Should use /tmp if /app/tmp doesn't exist.""" - mock_app_path = MagicMock() - mock_app_path.exists.return_value = False - mock_path.return_value = mock_app_path - - result = utils_embed.get_temp_directory("test_client", "embed") - - assert result is not None - - -class TestDocToJson: - """Tests for the doc_to_json function.""" - - def test_doc_to_json_creates_file(self, tmp_path): - """Should create JSON file from documents.""" - docs = [LangchainDocument(page_content="Test content", metadata={"source": "test.pdf"})] - - result = utils_embed.doc_to_json(docs, "test.pdf", str(tmp_path)) - - assert os.path.exists(result) - assert result.endswith(".json") - - -class TestProcessMetadata: - """Tests for the process_metadata function.""" - - def test_process_metadata_adds_metadata(self): - """Should add metadata to chunk.""" - chunk = LangchainDocument(page_content="Test content", metadata={"source": "/path/to/test.pdf", "page": 1}) - - result = utils_embed.process_metadata(1, chunk) - - assert len(result) == 1 - assert result[0].metadata["id"] == "test_1" - assert result[0].metadata["filename"] == "test.pdf" - - def test_process_metadata_includes_file_metadata(self): - """Should include file metadata if provided.""" - chunk = LangchainDocument(page_content="Test content", metadata={"source": "/path/to/doc.pdf"}) - file_metadata = {"doc.pdf": {"size": 1000, "time_modified": "2024-01-01", "etag": "abc123"}} - - result = 
utils_embed.process_metadata(1, chunk, file_metadata) - - assert result[0].metadata["size"] == 1000 - assert result[0].metadata["etag"] == "abc123" - - -class TestSplitDocument: - """Tests for the split_document function.""" - - def test_split_document_pdf(self): - """Should split PDF documents.""" - docs = [LangchainDocument(page_content="A" * 2000, metadata={"source": "test.pdf"})] - - result = utils_embed.split_document("default", 500, 50, docs, "pdf") - - assert len(result) > 0 - - def test_split_document_unsupported_extension(self): - """Should raise ValueError for unsupported extension.""" - docs = [LangchainDocument(page_content="Test", metadata={})] - - with pytest.raises(ValueError) as exc_info: - utils_embed.split_document("default", 500, 50, docs, "xyz") - - assert "Unsupported file type" in str(exc_info.value) - - -class TestGetDocumentLoader: # pylint: disable=protected-access - """Tests for the _get_document_loader function.""" - - def test_get_document_loader_pdf(self, tmp_path): - """Should return PyPDFLoader for PDF files.""" - test_file = tmp_path / "test.pdf" - test_file.touch() - - _, split = utils_embed._get_document_loader(str(test_file), "pdf") - - assert split is True - - def test_get_document_loader_html(self, tmp_path): - """Should return TextLoader for HTML files.""" - test_file = tmp_path / "test.html" - test_file.touch() - - _, split = utils_embed._get_document_loader(str(test_file), "html") - - assert split is True - - def test_get_document_loader_unsupported(self, tmp_path): - """Should raise ValueError for unsupported extension.""" - test_file = tmp_path / "test.xyz" - test_file.touch() - - with pytest.raises(ValueError): - utils_embed._get_document_loader(str(test_file), "xyz") - - -class TestCaptureFileMetadata: # pylint: disable=protected-access - """Tests for the _capture_file_metadata function.""" - - def test_capture_file_metadata_new_file(self, tmp_path): - """Should capture metadata for new files.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - stat = test_file.stat() - file_metadata = {} - - utils_embed._capture_file_metadata("test.txt", stat, file_metadata) - - assert "test.txt" in file_metadata - assert "size" in file_metadata["test.txt"] - assert "time_modified" in file_metadata["test.txt"] - - def test_capture_file_metadata_existing_file(self, tmp_path): - """Should not overwrite existing metadata.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - stat = test_file.stat() - file_metadata = {"test.txt": {"size": 9999}} - - utils_embed._capture_file_metadata("test.txt", stat, file_metadata) - - assert file_metadata["test.txt"]["size"] == 9999 # Not overwritten - - -class TestPrepareDocuments: # pylint: disable=protected-access - """Tests for the _prepare_documents function.""" - - def test_prepare_documents_removes_duplicates(self): - """Should remove duplicate documents.""" - docs = [ - LangchainDocument(page_content="Same content", metadata={}), - LangchainDocument(page_content="Same content", metadata={}), - LangchainDocument(page_content="Different content", metadata={}), - ] - - result = utils_embed._prepare_documents(docs) - - assert len(result) == 2 - - -class TestGetVectorStoreByAlias: - """Tests for the get_vector_store_by_alias function.""" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_vector_store_by_alias_success(self, _mock_disconnect, mock_connect, make_database): - """Should return vector 
store config for matching alias.""" - mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [ - ("VS_TEST", '{"alias": "test_alias", "model": "embed-3", "chunk_size": 500, "chunk_overlap": 100}') - ] - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - result = utils_embed.get_vector_store_by_alias(make_database(), "test_alias") - - assert result.vector_store == "VS_TEST" - assert result.alias == "test_alias" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_vector_store_by_alias_not_found(self, _mock_disconnect, mock_connect, make_database): - """Should raise ValueError if alias not found.""" - mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [] - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - with pytest.raises(ValueError) as exc_info: - utils_embed.get_vector_store_by_alias(make_database(), "nonexistent") - - assert "not found" in str(exc_info.value) - - -class TestGetTotalChunksCount: - """Tests for the get_total_chunks_count function.""" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_total_chunks_count_success(self, _mock_disconnect, mock_connect, make_database): - """Should return chunk count.""" - mock_cursor = MagicMock() - mock_cursor.fetchone.return_value = (150,) - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - result = utils_embed.get_total_chunks_count(make_database(), "VS_TEST") - - assert result == 150 - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_total_chunks_count_error(self, _mock_disconnect, mock_connect, make_database): - """Should return 0 on error.""" - mock_cursor = MagicMock() - mock_cursor.execute.side_effect = Exception("Query failed") - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - result = utils_embed.get_total_chunks_count(make_database(), "VS_TEST") - - assert result == 0 - - -class TestGetProcessedObjectsMetadata: - """Tests for the get_processed_objects_metadata function.""" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_processed_objects_metadata_new_format(self, _mock_disconnect, mock_connect, make_database): - """Should return metadata in new format.""" - mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [({"filename": "doc.pdf", "etag": "abc", "time_modified": "2024-01-01"},)] - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - result = utils_embed.get_processed_objects_metadata(make_database(), "VS_TEST") - - assert "doc.pdf" in result - assert result["doc.pdf"]["etag"] == "abc" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_processed_objects_metadata_old_format(self, _mock_disconnect, mock_connect, make_database): - """Should handle old format with source field.""" - mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [({"source": "/path/to/doc.pdf"},)] - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - 
mock_connect.return_value = mock_conn - - result = utils_embed.get_processed_objects_metadata(make_database(), "VS_TEST") - - assert "doc.pdf" in result - - -class TestGetVectorStoreFiles: - """Tests for the get_vector_store_files function.""" - - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed.utils_databases.disconnect") - def test_get_vector_store_files_success(self, _mock_disconnect, mock_connect, make_database): - """Should return file list with statistics.""" - mock_cursor = MagicMock() - mock_cursor.fetchall.return_value = [ - ({"filename": "doc1.pdf", "size": 1000},), - ({"filename": "doc1.pdf", "size": 1000},), - ({"filename": "doc2.pdf", "size": 2000},), - ] - mock_conn = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - result = utils_embed.get_vector_store_files(make_database(), "VS_TEST") - - assert result["total_files"] == 2 - assert result["total_chunks"] == 3 - - -class TestRefreshVectorStoreFromBucket: - """Tests for the refresh_vector_store_from_bucket function.""" - - @patch("server.api.utils.embed.get_temp_directory") - def test_refresh_vector_store_empty_objects( - self, _mock_get_temp, make_vector_store, make_database, make_oci_config - ): - """Should return early if no objects to process.""" - result = utils_embed.refresh_vector_store_from_bucket( - make_vector_store(), - "test-bucket", - [], - make_database(), - MagicMock(), - make_oci_config(), - ) - - assert result["processed_files"] == 0 - assert "No new or modified files" in result["message"] - - @patch("server.api.utils.embed.shutil.rmtree") - @patch("server.api.utils.embed.populate_vs") - @patch("server.api.utils.embed.load_and_split_documents") - @patch("server.api.utils.embed.utils_oci.get_object") - @patch("server.api.utils.embed.get_temp_directory") - def test_refresh_vector_store_success( - self, - mock_get_temp, - mock_get_object, - mock_load_split, - mock_populate, - _mock_rmtree, - make_vector_store, - make_database, - make_oci_config, - tmp_path, - ): - """Should process objects and populate vector store.""" - mock_get_temp.return_value = tmp_path - mock_get_object.return_value = str(tmp_path / "doc.pdf") - mock_load_split.return_value = ([LangchainDocument(page_content="test", metadata={})], []) - - bucket_objects = [{"name": "doc.pdf", "size": 1000, "time_modified": "2024-01-01", "etag": "abc"}] - - result = utils_embed.refresh_vector_store_from_bucket( - make_vector_store(), - "test-bucket", - bucket_objects, - make_database(), - MagicMock(), - make_oci_config(), - ) - - assert result["processed_files"] == 1 - mock_populate.assert_called_once() - - @patch("server.api.utils.embed.shutil.rmtree") - @patch("server.api.utils.embed.utils_oci.get_object") - @patch("server.api.utils.embed.get_temp_directory") - def test_refresh_vector_store_download_failure( - self, mock_get_temp, mock_get_object, _mock_rmtree, make_vector_store, make_database, make_oci_config, tmp_path - ): - """Should handle download failures gracefully.""" - mock_get_temp.return_value = tmp_path - mock_get_object.side_effect = Exception("Download failed") - - bucket_objects = [{"name": "doc.pdf", "size": 1000}] - - result = utils_embed.refresh_vector_store_from_bucket( - make_vector_store(), - "test-bucket", - bucket_objects, - make_database(), - MagicMock(), - make_oci_config(), - ) - - assert result["processed_files"] == 0 - assert "errors" in result - - -class TestLoadAndSplitDocuments: - """Tests for the load_and_split_documents 
function.""" - - @patch("server.api.utils.embed._get_document_loader") - @patch("server.api.utils.embed._process_and_split_document") - def test_load_and_split_documents_success(self, mock_process, mock_get_loader, tmp_path): - """Should load and split documents.""" - test_file = tmp_path / "test.txt" - test_file.write_text("Test content") - - mock_loader = MagicMock() - mock_loader.load.return_value = [LangchainDocument(page_content="Test", metadata={})] - mock_get_loader.return_value = (mock_loader, True) - mock_process.return_value = [LangchainDocument(page_content="Test", metadata={"id": "1"})] - - result, _ = utils_embed.load_and_split_documents([str(test_file)], "default", 500, 50) - - assert len(result) == 1 - - @patch("server.api.utils.embed._get_document_loader") - @patch("server.api.utils.embed._process_and_split_document") - @patch("server.api.utils.embed.doc_to_json") - def test_load_and_split_documents_with_json_output( - self, mock_doc_to_json, mock_process, mock_get_loader, tmp_path - ): - """Should write JSON when output_dir provided.""" - test_file = tmp_path / "test.txt" - test_file.write_text("Test content") - - mock_loader = MagicMock() - mock_loader.load.return_value = [LangchainDocument(page_content="Test", metadata={})] - mock_get_loader.return_value = (mock_loader, True) - mock_process.return_value = [LangchainDocument(page_content="Test", metadata={})] - mock_doc_to_json.return_value = str(tmp_path / "_test.json") - - _, split_files = utils_embed.load_and_split_documents( - [str(test_file)], "default", 500, 50, write_json=True, output_dir=str(tmp_path) - ) - - mock_doc_to_json.assert_called_once() - assert len(split_files) == 1 - - -class TestLoadAndSplitUrl: - """Tests for the load_and_split_url function.""" - - @patch("server.api.utils.embed.WebBaseLoader") - @patch("server.api.utils.embed.split_document") - def test_load_and_split_url_success(self, mock_split, mock_loader_class): - """Should load and split URL content.""" - mock_loader = MagicMock() - mock_loader.load.return_value = [ - LangchainDocument(page_content="Web content", metadata={"source": "http://example.com"}) - ] - mock_loader_class.return_value = mock_loader - mock_split.return_value = [LangchainDocument(page_content="Chunk", metadata={"source": "http://example.com"})] - - result, _ = utils_embed.load_and_split_url("default", "http://example.com", 500, 50) - - assert len(result) == 1 - - @patch("server.api.utils.embed.WebBaseLoader") - @patch("server.api.utils.embed.split_document") - def test_load_and_split_url_empty_content(self, mock_split, mock_loader_class): - """Should raise ValueError for empty content.""" - mock_loader = MagicMock() - mock_loader.load.return_value = [LangchainDocument(page_content="", metadata={})] - mock_loader_class.return_value = mock_loader - mock_split.return_value = [] - - with pytest.raises(ValueError) as exc_info: - utils_embed.load_and_split_url("default", "http://example.com", 500, 50) - - assert "no chunk-able data" in str(exc_info.value) - - -class TestJsonToDoc: # pylint: disable=protected-access - """Tests for the _json_to_doc function.""" - - def test_json_to_doc_success(self, tmp_path): - """Should convert JSON file to documents.""" - json_content = [ - {"kwargs": {"page_content": "Content 1", "metadata": {"source": "test.pdf"}}}, - {"kwargs": {"page_content": "Content 2", "metadata": {"source": "test.pdf"}}}, - ] - json_file = tmp_path / "test.json" - json_file.write_text(json.dumps(json_content)) - - result = utils_embed._json_to_doc(str(json_file)) - - 
assert len(result) == 2 - assert result[0].page_content == "Content 1" - - -class TestProcessAndSplitDocument: # pylint: disable=protected-access - """Tests for the _process_and_split_document function.""" - - @patch("server.api.utils.embed.split_document") - @patch("server.api.utils.embed.process_metadata") - def test_process_and_split_document_with_split(self, mock_process_meta, mock_split): - """Should split and process document.""" - mock_split.return_value = [LangchainDocument(page_content="Chunk", metadata={"source": "test.pdf"})] - mock_process_meta.return_value = [LangchainDocument(page_content="Chunk", metadata={"id": "1"})] - - loaded_doc = [LangchainDocument(page_content="Full content", metadata={})] - - result = utils_embed._process_and_split_document( - loaded_doc, - split=True, - model="default", - chunk_size=500, - chunk_overlap=50, - extension="pdf", - file_metadata={}, - ) - - mock_split.assert_called_once() - assert len(result) == 1 - - def test_process_and_split_document_no_split(self): - """Should return loaded doc without splitting.""" - loaded_doc = [LangchainDocument(page_content="Content", metadata={})] - - result = utils_embed._process_and_split_document( - loaded_doc, - split=False, - model="default", - chunk_size=500, - chunk_overlap=50, - extension="png", - file_metadata={}, - ) - - assert result == loaded_doc - - -class TestCreateTempVectorStore: # pylint: disable=protected-access - """Tests for the _create_temp_vector_store function.""" - - @patch("server.api.utils.embed.utils_databases.drop_vs") - @patch("server.api.utils.embed.OracleVS") - def test_create_temp_vector_store_success(self, mock_oracle_vs, mock_drop_vs, make_vector_store): - """Should create temporary vector store.""" - mock_vs = MagicMock() - mock_oracle_vs.return_value = mock_vs - mock_conn = MagicMock() - mock_embed_client = MagicMock() - vector_store = make_vector_store(vector_store="VS_TEST") - - _, vs_config_tmp = utils_embed._create_temp_vector_store(mock_conn, vector_store, mock_embed_client) - - assert vs_config_tmp.vector_store == "VS_TEST_TMP" - mock_drop_vs.assert_called_once() - - -class TestEmbedDocumentsInBatches: # pylint: disable=protected-access - """Tests for the _embed_documents_in_batches function.""" - - @patch("server.api.utils.embed.OracleVS.add_documents") - def test_embed_documents_in_batches_no_rate_limit(self, mock_add_docs): - """Should embed documents without rate limiting.""" - mock_vs = MagicMock() - chunks = [LangchainDocument(page_content=f"Chunk {i}", metadata={}) for i in range(10)] - - utils_embed._embed_documents_in_batches(mock_vs, chunks, rate_limit=0) - - mock_add_docs.assert_called_once() - - @patch("server.api.utils.embed.time.sleep") - @patch("server.api.utils.embed.OracleVS.add_documents") - def test_embed_documents_in_batches_with_rate_limit(self, mock_add_docs, mock_sleep): - """Should apply rate limiting between batches.""" - mock_vs = MagicMock() - # Create 600 chunks to trigger multiple batches (batch_size=500) - chunks = [LangchainDocument(page_content=f"Chunk {i}", metadata={}) for i in range(600)] - - utils_embed._embed_documents_in_batches(mock_vs, chunks, rate_limit=60) - - assert mock_add_docs.call_count == 2 # 500 + 100 - mock_sleep.assert_called() # Rate limiting applied - - -class TestMergeAndIndexVectorStore: # pylint: disable=protected-access - """Tests for the _merge_and_index_vector_store function.""" - - @patch("server.api.utils.embed.LangchainVS.create_index") - @patch("server.api.utils.embed.utils_databases.drop_vs") - 
@patch("server.api.utils.embed.utils_databases.execute_sql") - @patch("server.api.utils.embed.LangchainVS.drop_index_if_exists") - @patch("server.api.utils.embed.OracleVS") - def test_merge_and_index_vector_store_hnsw( - self, _mock_oracle_vs, mock_drop_idx, mock_execute, mock_drop_vs, mock_create_idx, make_vector_store - ): - """Should merge temp store and create HNSW index.""" - mock_conn = MagicMock() - vector_store = make_vector_store(vector_store="VS_TEST", index_type="HNSW") - vector_store_tmp = make_vector_store(vector_store="VS_TEST_TMP") - - utils_embed._merge_and_index_vector_store(mock_conn, vector_store, vector_store_tmp, MagicMock()) - - mock_drop_idx.assert_called_once() # HNSW drops existing index - mock_execute.assert_called_once() # Merge SQL - mock_drop_vs.assert_called_once() # Drop temp table - mock_create_idx.assert_called_once() # Create index - - -class TestPopulateVs: - """Tests for the populate_vs function.""" - - @patch("server.api.utils.embed.update_vs_comment") - @patch("server.api.utils.embed._merge_and_index_vector_store") - @patch("server.api.utils.embed._embed_documents_in_batches") - @patch("server.api.utils.embed._create_temp_vector_store") - @patch("server.api.utils.embed.utils_databases.connect") - @patch("server.api.utils.embed._prepare_documents") - def test_populate_vs_success( - self, - mock_prepare, - mock_connect, - mock_create_temp, - mock_embed, - mock_merge, - mock_comment, - make_vector_store, - make_database, - ): - """Should populate vector store with documents.""" - mock_prepare.return_value = [LangchainDocument(page_content="Test", metadata={})] - mock_conn = MagicMock() - mock_connect.return_value = mock_conn - mock_create_temp.return_value = (MagicMock(), make_vector_store(vector_store="VS_TMP")) - - docs = [LangchainDocument(page_content="Test", metadata={})] - - utils_embed.populate_vs(make_vector_store(), make_database(), MagicMock(), input_data=docs) - - mock_prepare.assert_called_once() - mock_create_temp.assert_called_once() - mock_embed.assert_called_once() - mock_merge.assert_called_once() - mock_comment.assert_called_once() - - -class TestSplitDocumentExtensions: - """Tests for split_document with various extensions.""" - - def test_split_document_html(self): - """Should split HTML documents using HTMLHeaderTextSplitter.""" - docs = [LangchainDocument(page_content="
<html><body><h1>Title</h1><p>Content here</p></body></html>
", metadata={"source": "test.html"})] - - result = utils_embed.split_document("default", 500, 50, docs, "html") - - assert len(result) >= 1 - - def test_split_document_md(self): - """Should split Markdown documents.""" - docs = [LangchainDocument(page_content="# Header\n\nContent " * 100, metadata={"source": "test.md"})] - - result = utils_embed.split_document("default", 500, 50, docs, "md") - - assert len(result) >= 1 - - def test_split_document_txt(self): - """Should split text documents.""" - docs = [LangchainDocument(page_content="Text content " * 200, metadata={"source": "test.txt"})] - - result = utils_embed.split_document("default", 500, 50, docs, "txt") - - assert len(result) >= 1 - - def test_split_document_csv(self): - """Should split CSV documents.""" - docs = [LangchainDocument(page_content="col1,col2\nval1,val2\n" * 100, metadata={"source": "test.csv"})] - - result = utils_embed.split_document("default", 500, 50, docs, "csv") - - assert len(result) >= 1 - - -class TestGetDocumentLoaderExtensions: # pylint: disable=protected-access - """Tests for _get_document_loader with various extensions.""" - - def test_get_document_loader_md(self, tmp_path): - """Should return TextLoader for Markdown files.""" - test_file = tmp_path / "test.md" - test_file.touch() - - _, split = utils_embed._get_document_loader(str(test_file), "md") - - assert split is True - - def test_get_document_loader_csv(self, tmp_path): - """Should return CSVLoader for CSV files.""" - test_file = tmp_path / "test.csv" - test_file.write_text("col1,col2\nval1,val2") - - _, split = utils_embed._get_document_loader(str(test_file), "csv") - - assert split is True - - def test_get_document_loader_txt(self, tmp_path): - """Should return TextLoader for text files.""" - test_file = tmp_path / "test.txt" - test_file.touch() - - _, split = utils_embed._get_document_loader(str(test_file), "txt") - - assert split is True - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_embed, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_embed.logger.name == "api.utils.embed" diff --git a/test/unit/server/api/utils/test_utils_mcp.py b/test/unit/server/api/utils/test_utils_mcp.py deleted file mode 100644 index 0301a321..00000000 --- a/test/unit/server/api/utils/test_utils_mcp.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/mcp.py -Tests for MCP utility functions. 
-""" - -from unittest.mock import patch, MagicMock, AsyncMock -import os -import pytest - -from server.api.utils import mcp - - -class TestGetClient: - """Tests for the get_client function.""" - - @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) - def test_get_client_default_values(self): - """get_client should return default configuration.""" - result = mcp.get_client() - - assert "mcpServers" in result - assert "optimizer" in result["mcpServers"] - assert result["mcpServers"]["optimizer"]["type"] == "streamableHttp" - assert result["mcpServers"]["optimizer"]["transport"] == "streamable_http" - assert "http://127.0.0.1:8000/mcp/" in result["mcpServers"]["optimizer"]["url"] - - @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) - def test_get_client_custom_server_port(self): - """get_client should use custom server and port.""" - result = mcp.get_client(server="http://custom.server", port=9000) - - assert "http://custom.server:9000/mcp/" in result["mcpServers"]["optimizer"]["url"] - - @patch.dict(os.environ, {"API_SERVER_KEY": "secret-key"}) - def test_get_client_includes_auth_header(self): - """get_client should include authorization header.""" - result = mcp.get_client() - - headers = result["mcpServers"]["optimizer"]["headers"] - assert "Authorization" in headers - assert headers["Authorization"] == "Bearer secret-key" - - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) - def test_get_client_langgraph_removes_type(self): - """get_client should remove type field for langgraph client.""" - result = mcp.get_client(client="langgraph") - - assert "type" not in result["mcpServers"]["optimizer"] - assert "transport" in result["mcpServers"]["optimizer"] - - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) - def test_get_client_non_langgraph_keeps_type(self): - """get_client should keep type field for non-langgraph clients.""" - result = mcp.get_client(client="other") - - assert "type" in result["mcpServers"]["optimizer"] - - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) - def test_get_client_none_client_keeps_type(self): - """get_client should keep type field when client is None.""" - result = mcp.get_client(client=None) - - assert "type" in result["mcpServers"]["optimizer"] - - @patch.dict(os.environ, {"API_SERVER_KEY": ""}) - def test_get_client_empty_api_key(self): - """get_client should handle empty API key.""" - result = mcp.get_client() - - headers = result["mcpServers"]["optimizer"]["headers"] - assert headers["Authorization"] == "Bearer " - - @patch.dict(os.environ, {"API_SERVER_KEY": "key"}) - def test_get_client_structure(self): - """get_client should return expected structure.""" - result = mcp.get_client() - - assert isinstance(result, dict) - assert isinstance(result["mcpServers"], dict) - assert isinstance(result["mcpServers"]["optimizer"], dict) - - optimizer = result["mcpServers"]["optimizer"] - expected_keys = {"type", "transport", "url", "headers"} - assert set(optimizer.keys()) == expected_keys - - -class TestListPrompts: - """Tests for the list_prompts function.""" - - @pytest.mark.asyncio - @patch("server.api.utils.mcp.Client") - async def test_list_prompts_success(self, mock_client_class): - """list_prompts should return list of prompts.""" - mock_prompts = [MagicMock(name="prompt1"), MagicMock(name="prompt2")] - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=mock_prompts) - 
mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - mock_mcp_engine = MagicMock() - - result = await mcp.list_prompts(mock_mcp_engine) - - assert result == mock_prompts - mock_client.list_prompts.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.utils.mcp.Client") - async def test_list_prompts_empty_list(self, mock_client_class): - """list_prompts should return empty list when no prompts.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - mock_mcp_engine = MagicMock() - - result = await mcp.list_prompts(mock_mcp_engine) - - assert result == [] - - @pytest.mark.asyncio - @patch("server.api.utils.mcp.Client") - async def test_list_prompts_closes_client(self, mock_client_class): - """list_prompts should close client after use.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - mock_mcp_engine = MagicMock() - - await mcp.list_prompts(mock_mcp_engine) - - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.utils.mcp.Client") - async def test_list_prompts_creates_client_with_engine(self, mock_client_class): - """list_prompts should create client with MCP engine.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - mock_mcp_engine = MagicMock() - - await mcp.list_prompts(mock_mcp_engine) - - mock_client_class.assert_called_once_with(mock_mcp_engine) - - @pytest.mark.asyncio - @patch("server.api.utils.mcp.Client") - async def test_list_prompts_closes_client_on_exception(self, mock_client_class): - """list_prompts should close client even if exception occurs.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(side_effect=RuntimeError("Test error")) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - mock_mcp_engine = MagicMock() - - with pytest.raises(RuntimeError): - await mcp.list_prompts(mock_mcp_engine) - - mock_client.close.assert_called_once() - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(mcp, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert mcp.logger.name == "api.utils.mcp" diff --git a/test/unit/server/api/utils/test_utils_models.py b/test/unit/server/api/utils/test_utils_models.py deleted file mode 100644 index 8616ca9e..00000000 --- a/test/unit/server/api/utils/test_utils_models.py +++ /dev/null @@ -1,433 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/models.py -Tests for model utility functions. 
-""" - -# pylint: disable=too-few-public-methods - -from unittest.mock import patch, MagicMock -import pytest - -from server.api.utils import models as utils_models -from server.api.utils.models import ( - URLUnreachableError, - InvalidModelError, - ExistsModelError, - UnknownModelError, -) - - -class TestExceptions: - """Tests for custom exception classes.""" - - def test_url_unreachable_error_is_value_error(self): - """URLUnreachableError should inherit from ValueError.""" - exc = URLUnreachableError("URL unreachable") - assert isinstance(exc, ValueError) - - def test_invalid_model_error_is_value_error(self): - """InvalidModelError should inherit from ValueError.""" - exc = InvalidModelError("Invalid model") - assert isinstance(exc, ValueError) - - def test_exists_model_error_is_value_error(self): - """ExistsModelError should inherit from ValueError.""" - exc = ExistsModelError("Model exists") - assert isinstance(exc, ValueError) - - def test_unknown_model_error_is_value_error(self): - """UnknownModelError should inherit from ValueError.""" - exc = UnknownModelError("Model not found") - assert isinstance(exc, ValueError) - - -class TestCreate: - """Tests for the create function.""" - - @patch("server.api.utils.models.get") - @patch("server.api.utils.models.MODEL_OBJECTS", []) - def test_create_success(self, mock_get, make_model): - """create should add model to MODEL_OBJECTS.""" - model = make_model(model_id="gpt-4", provider="openai") - mock_get.side_effect = [UnknownModelError("Not found"), (model,)] - - result = utils_models.create(model) - - assert result == model - - @patch("server.api.utils.models.get") - def test_create_raises_exists_error(self, mock_get, make_model): - """create should raise ExistsModelError if model exists.""" - model = make_model(model_id="gpt-4", provider="openai") - mock_get.return_value = [model] - - with pytest.raises(ExistsModelError): - utils_models.create(model) - - @patch("server.api.utils.models.get") - @patch("server.api.utils.models.is_url_accessible") - @patch("server.api.utils.models.MODEL_OBJECTS", []) - def test_create_disables_model_if_url_inaccessible(self, mock_url_check, mock_get, make_model): - """create should disable model if API base URL is inaccessible.""" - model = make_model(model_id="custom", provider="openai") - model.api_base = "https://unreachable.example.com" - mock_get.side_effect = [UnknownModelError("Not found"), (model,)] - mock_url_check.return_value = (False, "Connection refused") - - result = utils_models.create(model, check_url=True) - - assert result.enabled is False - - -class TestGet: - """Tests for the get function.""" - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_all_models(self, mock_objects, make_model): - """get should return all models when no filters.""" - model1 = make_model(model_id="gpt-4", provider="openai") - model2 = make_model(model_id="claude-3", provider="anthropic") - mock_objects.__iter__ = lambda _: iter([model1, model2]) - mock_objects.__len__ = lambda _: 2 - - result = utils_models.get() - - assert len(result) == 2 - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_by_provider(self, mock_objects, make_model): - """get should filter by provider.""" - model1 = make_model(model_id="gpt-4", provider="openai") - model2 = make_model(model_id="claude-3", provider="anthropic") - mock_objects.__iter__ = lambda _: iter([model1, model2]) - mock_objects.__len__ = lambda _: 2 - - result = utils_models.get(model_provider="openai") - - assert len(result) == 1 - assert 
result[0].provider == "openai" - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_by_type(self, mock_objects, make_model): - """get should filter by type.""" - model1 = make_model(model_id="gpt-4", model_type="ll") - model2 = make_model(model_id="embed-3", model_type="embed") - mock_objects.__iter__ = lambda _: iter([model1, model2]) - mock_objects.__len__ = lambda _: 2 - - result = utils_models.get(model_type="embed") - - assert len(result) == 1 - assert result[0].type == "embed" - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_exclude_disabled(self, mock_objects, make_model): - """get should exclude disabled models when include_disabled=False.""" - model1 = make_model(model_id="gpt-4", enabled=True) - model2 = make_model(model_id="gpt-3", enabled=False) - mock_objects.__iter__ = lambda _: iter([model1, model2]) - mock_objects.__len__ = lambda _: 2 - - result = utils_models.get(include_disabled=False) - - assert len(result) == 1 - assert result[0].enabled is True - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_raises_unknown_error(self, mock_objects): - """get should raise UnknownModelError if model_id not found.""" - mock_objects.__iter__ = lambda _: iter([]) - mock_objects.__len__ = lambda _: 0 - - with pytest.raises(UnknownModelError): - utils_models.get(model_id="nonexistent") - - -class TestUpdate: - """Tests for the update function.""" - - @patch("server.api.utils.models.get") - @patch("server.api.utils.models.is_url_accessible") - def test_update_success(self, mock_url_check, mock_get, make_model): - """update should update model in place.""" - existing_model = make_model(model_id="gpt-4", provider="openai") - mock_get.return_value = (existing_model,) - mock_url_check.return_value = (True, "OK") - - payload = make_model(model_id="gpt-4", provider="openai") - payload.temperature = 0.9 - - result = utils_models.update(payload) - - assert result.temperature == 0.9 - - @patch("server.api.utils.models.get") - @patch("server.api.utils.models.is_url_accessible") - def test_update_raises_url_unreachable(self, mock_url_check, mock_get, make_model): - """update should raise URLUnreachableError if URL inaccessible.""" - existing_model = make_model(model_id="gpt-4", provider="openai") - mock_get.return_value = (existing_model,) - mock_url_check.return_value = (False, "Connection refused") - - payload = make_model(model_id="gpt-4", provider="openai", enabled=True) - payload.api_base = "https://unreachable.example.com" - - with pytest.raises(URLUnreachableError): - utils_models.update(payload) - - -class TestDelete: - """Tests for the delete function.""" - - def test_delete_removes_model(self, make_model): - """delete should remove model from MODEL_OBJECTS.""" - model1 = make_model(model_id="gpt-4", provider="openai") - model2 = make_model(model_id="claude-3", provider="anthropic") - - with patch("server.api.utils.models.MODEL_OBJECTS", [model1, model2]) as mock_objects: - utils_models.delete("openai", "gpt-4") - assert len(mock_objects) == 1 - assert mock_objects[0].id == "claude-3" - - -class TestGetSupported: - """Tests for the get_supported function.""" - - @patch("server.api.utils.models.litellm") - def test_get_supported_returns_providers(self, mock_litellm): - """get_supported should return list of providers.""" - mock_provider = MagicMock() - mock_provider.value = "openai" - mock_litellm.provider_list = [mock_provider] - mock_litellm.models_by_provider = {"openai": ["gpt-4"]} - mock_litellm.get_model_info.return_value = {"mode": 
"chat", "key": "gpt-4"} - mock_litellm.get_llm_provider.return_value = ("openai", None, None, "https://api.openai.com/v1") - - result = utils_models.get_supported() - - assert len(result) >= 1 - assert result[0]["provider"] == "openai" - - @patch("server.api.utils.models.litellm") - def test_get_supported_filters_by_provider(self, mock_litellm): - """get_supported should filter by provider.""" - mock_provider1 = MagicMock() - mock_provider1.value = "openai" - mock_provider2 = MagicMock() - mock_provider2.value = "anthropic" - mock_litellm.provider_list = [mock_provider1, mock_provider2] - mock_litellm.models_by_provider = {"openai": [], "anthropic": []} - - result = utils_models.get_supported(model_provider="anthropic") - - assert len(result) == 1 - assert result[0]["provider"] == "anthropic" - - -class TestCreateGenai: - """Tests for the create_genai function.""" - - @patch("server.api.utils.models.utils_oci.get_genai_models") - @patch("server.api.utils.models.get") - @patch("server.api.utils.models.delete") - @patch("server.api.utils.models.create") - def test_create_genai_creates_models(self, mock_create, _mock_delete, mock_get, mock_get_genai, make_oci_config): - """create_genai should create GenAI models.""" - mock_get_genai.return_value = [ - {"model_name": "cohere.command-r", "capabilities": ["CHAT"]}, - {"model_name": "cohere.embed-v3", "capabilities": ["TEXT_EMBEDDINGS"]}, - ] - mock_get.return_value = [] - - config = make_oci_config(genai_region="us-chicago-1") - config.genai_compartment_id = "ocid1.compartment.oc1..test" - - utils_models.create_genai(config) - - assert mock_create.call_count == 2 - - @patch("server.api.utils.models.utils_oci.get_genai_models") - def test_create_genai_returns_empty_when_no_models(self, mock_get_genai, make_oci_config): - """create_genai should return empty list when no models.""" - mock_get_genai.return_value = [] - - config = make_oci_config(genai_region="us-chicago-1") - - result = utils_models.create_genai(config) - - assert not result - - -class TestGetFullConfig: # pylint: disable=protected-access - """Tests for the _get_full_config function.""" - - @patch("server.api.utils.models.get") - def test_get_full_config_success(self, mock_get, make_model): - """_get_full_config should merge model config with defined model.""" - defined_model = make_model(model_id="gpt-4", provider="openai") - defined_model.api_base = "https://api.openai.com/v1" - mock_get.return_value = (defined_model,) - - model_config = {"model": "openai/gpt-4", "temperature": 0.9} - - full_config, provider = utils_models._get_full_config(model_config, None) - - assert provider == "openai" - assert full_config["temperature"] == 0.9 - assert full_config["api_base"] == "https://api.openai.com/v1" - - @patch("server.api.utils.models.get") - def test_get_full_config_raises_unknown_model(self, mock_get): - """_get_full_config should raise UnknownModelError if not found.""" - mock_get.side_effect = UnknownModelError("Model not found") - - model_config = {"model": "openai/nonexistent"} - - with pytest.raises(UnknownModelError): - utils_models._get_full_config(model_config, None) - - -class TestGetLitellmConfig: - """Tests for the get_litellm_config function.""" - - @patch("server.api.utils.models._get_full_config") - @patch("server.api.utils.models.litellm.get_supported_openai_params") - def test_get_litellm_config_basic(self, mock_get_params, mock_get_full): - """get_litellm_config should return LiteLLM config.""" - mock_get_full.return_value = ( - {"model": "openai/gpt-4", 
"temperature": 0.7, "api_base": "https://api.openai.com/v1"}, - "openai", - ) - mock_get_params.return_value = ["temperature", "max_tokens"] - - model_config = {"model": "openai/gpt-4"} - - result = utils_models.get_litellm_config(model_config, None) - - assert result["model"] == "openai/gpt-4" - assert result["drop_params"] is True - - @patch("server.api.utils.models._get_full_config") - @patch("server.api.utils.models.litellm.get_supported_openai_params") - @patch("server.api.utils.models.utils_oci.get_signer") - def test_get_litellm_config_oci_provider(self, mock_get_signer, mock_get_params, mock_get_full, make_oci_config): - """get_litellm_config should include OCI params for OCI provider.""" - mock_get_full.return_value = ( - { - "model": "oci/cohere.command-r", - "api_base": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", - }, - "oci", - ) - mock_get_params.return_value = ["temperature"] - mock_get_signer.return_value = None # API key auth - - oci_config = make_oci_config(genai_region="us-chicago-1") - oci_config.genai_compartment_id = "ocid1.compartment.oc1..test" - oci_config.tenancy = "test-tenancy" - oci_config.user = "test-user" - oci_config.fingerprint = "test-fingerprint" - oci_config.key_file = "/path/to/key" - - model_config = {"model": "oci/cohere.command-r"} - - result = utils_models.get_litellm_config(model_config, oci_config) - - assert result["oci_region"] == "us-chicago-1" - assert result["oci_compartment_id"] == "ocid1.compartment.oc1..test" - - -class TestGetClientEmbed: - """Tests for the get_client_embed function.""" - - @patch("server.api.utils.models._get_full_config") - @patch("server.api.utils.models.utils_oci.init_genai_client") - @patch("server.api.utils.models.OCIGenAIEmbeddings") - def test_get_client_embed_oci(self, mock_embeddings, mock_init_client, mock_get_full, make_oci_config): - """get_client_embed should return OCIGenAIEmbeddings for OCI provider.""" - mock_get_full.return_value = ({"id": "cohere.embed-v3"}, "oci") - mock_init_client.return_value = MagicMock() - mock_embeddings.return_value = MagicMock() - - oci_config = make_oci_config() - oci_config.genai_compartment_id = "ocid1.compartment.oc1..test" - - model_config = {"model": "oci/cohere.embed-v3"} - - utils_models.get_client_embed(model_config, oci_config) - - mock_embeddings.assert_called_once() - - @patch("server.api.utils.models._get_full_config") - @patch("server.api.utils.models.init_embeddings") - def test_get_client_embed_openai(self, mock_init_embeddings, mock_get_full, make_oci_config): - """get_client_embed should use init_embeddings for non-OCI providers.""" - mock_get_full.return_value = ( - {"id": "text-embedding-3-small", "api_base": "https://api.openai.com/v1"}, - "openai", - ) - mock_init_embeddings.return_value = MagicMock() - - oci_config = make_oci_config() - model_config = {"model": "openai/text-embedding-3-small"} - - utils_models.get_client_embed(model_config, oci_config) - - mock_init_embeddings.assert_called_once() - - -class TestProcessModelEntry: # pylint: disable=protected-access - """Tests for the _process_model_entry function.""" - - @patch("server.api.utils.models.litellm") - def test_process_model_entry_success(self, mock_litellm): - """_process_model_entry should return model dict.""" - mock_litellm.get_model_info.return_value = {"mode": "chat", "key": "gpt-4"} - mock_litellm.get_llm_provider.return_value = ("openai", None, None, "https://api.openai.com/v1") - - type_to_modes = {"ll": {"chat"}} - allowed_modes = {"chat"} - - result = 
utils_models._process_model_entry("gpt-4", type_to_modes, allowed_modes, "openai") - - assert result is not None - assert result["type"] == "ll" - - @patch("server.api.utils.models.litellm") - def test_process_model_entry_filters_mode(self, mock_litellm): - """_process_model_entry should return None for unsupported modes.""" - mock_litellm.get_model_info.return_value = {"mode": "moderation"} - - type_to_modes = {"ll": {"chat"}} - allowed_modes = {"chat"} - - result = utils_models._process_model_entry("mod-model", type_to_modes, allowed_modes, "openai") - - assert result is None - - @patch("server.api.utils.models.litellm") - def test_process_model_entry_handles_exception(self, mock_litellm): - """_process_model_entry should handle exceptions gracefully.""" - mock_litellm.get_model_info.side_effect = Exception("API error") - - type_to_modes = {"ll": {"chat"}} - allowed_modes = {"chat"} - - result = utils_models._process_model_entry("bad-model", type_to_modes, allowed_modes, "openai") - - assert result == {"key": "bad-model"} - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_models, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_models.logger.name == "api.utils.models" diff --git a/test/unit/server/api/utils/test_utils_oci.py b/test/unit/server/api/utils/test_utils_oci.py deleted file mode 100644 index 3e6d9f2b..00000000 --- a/test/unit/server/api/utils/test_utils_oci.py +++ /dev/null @@ -1,595 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/oci.py -Tests for OCI utility functions. 
-""" - -# pylint: disable=too-few-public-methods - -from datetime import datetime -from unittest.mock import patch, MagicMock - -import pytest -import oci - -from server.api.utils import oci as utils_oci -from server.api.utils.oci import OciException - - -class TestOciException: - """Tests for OciException class.""" - - def test_oci_exception_init(self): - """OciException should store status_code and detail.""" - exc = OciException(status_code=404, detail="Not found") - assert exc.status_code == 404 - assert exc.detail == "Not found" - - def test_oci_exception_message(self): - """OciException should use detail as message.""" - exc = OciException(status_code=500, detail="Server error") - assert str(exc) == "Server error" - - -class TestGet: - """Tests for the get function.""" - - @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS", []) - def test_get_raises_value_error_when_not_configured(self): - """get should raise ValueError when no OCI objects configured.""" - with pytest.raises(ValueError) as exc_info: - utils_oci.get() - assert "not configured" in str(exc_info.value) - - @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") - def test_get_returns_all_oci_objects(self, mock_objects, make_oci_config): - """get should return all OCI objects when no filters.""" - oci1 = make_oci_config(auth_profile="PROFILE1") - oci2 = make_oci_config(auth_profile="PROFILE2") - mock_objects.__iter__ = lambda _: iter([oci1, oci2]) - mock_objects.__len__ = lambda _: 2 - mock_objects.__bool__ = lambda _: True - - result = utils_oci.get() - - assert len(result) == 2 - - @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") - def test_get_by_auth_profile(self, mock_objects, make_oci_config): - """get should return matching OCI object by auth_profile.""" - oci1 = make_oci_config(auth_profile="PROFILE1") - oci2 = make_oci_config(auth_profile="PROFILE2") - mock_objects.__iter__ = lambda _: iter([oci1, oci2]) - - result = utils_oci.get(auth_profile="PROFILE1") - - assert result.auth_profile == "PROFILE1" - - @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") - def test_get_raises_value_error_profile_not_found(self, mock_objects, make_oci_config): - """get should raise ValueError when profile not found.""" - mock_objects.__iter__ = lambda _: iter([make_oci_config(auth_profile="DEFAULT")]) - - with pytest.raises(ValueError) as exc_info: - utils_oci.get(auth_profile="NONEXISTENT") - - assert "not found" in str(exc_info.value) - - def test_get_raises_value_error_both_params(self): - """get should raise ValueError when both client and auth_profile provided.""" - with pytest.raises(ValueError) as exc_info: - utils_oci.get(client="test", auth_profile="DEFAULT") - - assert "not both" in str(exc_info.value) - - @patch("server.api.utils.oci.bootstrap.SETTINGS_OBJECTS") - @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") - def test_get_by_client(self, mock_oci, mock_settings, make_oci_config, make_settings): - """get should return OCI object based on client settings.""" - settings = make_settings(client="test_client") - settings.oci.auth_profile = "CLIENT_PROFILE" - mock_settings.__iter__ = lambda _: iter([settings]) - mock_settings.__len__ = lambda _: 1 - - oci_config = make_oci_config(auth_profile="CLIENT_PROFILE") - mock_oci.__iter__ = lambda _: iter([oci_config]) - - result = utils_oci.get(client="test_client") - - assert result.auth_profile == "CLIENT_PROFILE" - - @patch("server.api.utils.oci.bootstrap.SETTINGS_OBJECTS", []) - def test_get_raises_value_error_client_not_found(self): - """get should raise ValueError 
when client not found.""" - with pytest.raises(ValueError) as exc_info: - utils_oci.get(client="nonexistent") - - assert "not found" in str(exc_info.value) - - -class TestGetSigner: - """Tests for the get_signer function.""" - - @patch("server.api.utils.oci.oci.auth.signers.InstancePrincipalsSecurityTokenSigner") - def test_get_signer_instance_principal(self, mock_signer_class, make_oci_config): - """get_signer should return instance principal signer.""" - mock_signer = MagicMock() - mock_signer_class.return_value = mock_signer - config = make_oci_config() - config.authentication = "instance_principal" - - result = utils_oci.get_signer(config) - - assert result == mock_signer - mock_signer_class.assert_called_once() - - @patch("server.api.utils.oci.oci.auth.signers.get_oke_workload_identity_resource_principal_signer") - def test_get_signer_oke_workload_identity(self, mock_signer_func, make_oci_config): - """get_signer should return OKE workload identity signer.""" - mock_signer = MagicMock() - mock_signer_func.return_value = mock_signer - config = make_oci_config() - config.authentication = "oke_workload_identity" - - result = utils_oci.get_signer(config) - - assert result == mock_signer - - def test_get_signer_api_key_returns_none(self, make_oci_config): - """get_signer should return None for API key authentication.""" - config = make_oci_config() - config.authentication = "api_key" - - result = utils_oci.get_signer(config) - - assert result is None - - -class TestInitClient: - """Tests for the init_client function.""" - - @patch("server.api.utils.oci.get_signer") - @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") - def test_init_client_standard_auth(self, mock_client_class, mock_get_signer, make_oci_config): - """init_client should initialize with standard authentication.""" - mock_get_signer.return_value = None - mock_client = MagicMock() - mock_client_class.return_value = mock_client - config = make_oci_config() - - result = utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) - - assert result == mock_client - - @patch("server.api.utils.oci.get_signer") - @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") - def test_init_client_with_signer(self, mock_client_class, mock_get_signer, make_oci_config): - """init_client should use signer when provided.""" - mock_signer = MagicMock() - mock_signer.tenancy_id = "test-tenancy-id" - mock_get_signer.return_value = mock_signer - mock_client = MagicMock() - mock_client_class.return_value = mock_client - config = make_oci_config() - config.authentication = "instance_principal" - config.region = "us-ashburn-1" # Required for signer-based auth - config.tenancy = "existing-tenancy" # Set tenancy so code doesn't try to derive from signer - - result = utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) - - assert result == mock_client - # Check signer was passed to client - call_kwargs = mock_client_class.call_args.kwargs - assert call_kwargs["signer"] == mock_signer - - @patch("server.api.utils.oci.get_signer") - def test_init_client_raises_oci_exception_on_invalid_config(self, mock_get_signer, make_oci_config): - """init_client should raise OciException on invalid config.""" - mock_get_signer.return_value = None - config = make_oci_config() - - with patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") as mock_client: - mock_client.side_effect = oci.exceptions.InvalidConfig("Invalid configuration") - - with pytest.raises(OciException) as exc_info: - 
utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) - - assert exc_info.value.status_code == 400 - - @patch("server.api.utils.oci.get_signer") - @patch("server.api.utils.oci.oci.generative_ai_inference.GenerativeAiInferenceClient") - def test_init_client_genai_sets_service_endpoint(self, mock_client_class, mock_get_signer, make_oci_config): - """init_client should set service endpoint for GenAI client.""" - mock_get_signer.return_value = None - mock_client = MagicMock() - mock_client_class.return_value = mock_client - config = make_oci_config(genai_region="us-chicago-1") - config.genai_compartment_id = "ocid1.compartment.oc1..test" - - utils_oci.init_client(oci.generative_ai_inference.GenerativeAiInferenceClient, config) - - call_kwargs = mock_client_class.call_args.kwargs - assert "inference.generativeai.us-chicago-1.oci.oraclecloud.com" in call_kwargs["service_endpoint"] - - -class TestGetNamespace: - """Tests for the get_namespace function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_namespace_success(self, mock_init_client, make_oci_config): - """get_namespace should return namespace on success.""" - mock_client = MagicMock() - mock_client.get_namespace.return_value.data = "test-namespace" - mock_init_client.return_value = mock_client - config = make_oci_config() - - result = utils_oci.get_namespace(config) - - assert result == "test-namespace" - assert config.namespace == "test-namespace" - - @patch("server.api.utils.oci.init_client") - def test_get_namespace_raises_on_service_error(self, mock_init_client, make_oci_config): - """get_namespace should raise OciException on service error.""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = oci.exceptions.ServiceError( - status=401, code="NotAuthenticated", headers={}, message="Not authenticated" - ) - mock_init_client.return_value = mock_client - config = make_oci_config() - - with pytest.raises(OciException) as exc_info: - utils_oci.get_namespace(config) - - assert exc_info.value.status_code == 401 - - @patch("server.api.utils.oci.init_client") - def test_get_namespace_raises_on_file_not_found(self, mock_init_client, make_oci_config): - """get_namespace should raise OciException on file not found.""" - mock_init_client.side_effect = FileNotFoundError("Key file not found") - config = make_oci_config() - - with pytest.raises(OciException) as exc_info: - utils_oci.get_namespace(config) - - assert exc_info.value.status_code == 400 - - -class TestGetRegions: - """Tests for the get_regions function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_regions_returns_list(self, mock_init_client, make_oci_config): - """get_regions should return list of region subscriptions.""" - mock_region = MagicMock() - mock_region.is_home_region = True - mock_region.region_key = "IAD" - mock_region.region_name = "us-ashburn-1" - mock_region.status = "READY" - - mock_client = MagicMock() - mock_client.list_region_subscriptions.return_value.data = [mock_region] - mock_init_client.return_value = mock_client - config = make_oci_config() - config.tenancy = "test-tenancy" - - result = utils_oci.get_regions(config) - - assert len(result) == 1 - assert result[0]["region_name"] == "us-ashburn-1" - assert result[0]["is_home_region"] is True - - -class TestGetGenaiModels: - """Tests for the get_genai_models function.""" - - def test_get_genai_models_raises_without_compartment(self, make_oci_config): - """get_genai_models should raise OciException without compartment_id.""" - config = 
make_oci_config() - config.genai_compartment_id = None - - with pytest.raises(OciException) as exc_info: - utils_oci.get_genai_models(config) - - assert exc_info.value.status_code == 400 - assert "genai_compartment_id" in exc_info.value.detail - - def test_get_genai_models_regional_raises_without_region(self, make_oci_config): - """get_genai_models should raise OciException without region when regional=True.""" - config = make_oci_config() - config.genai_compartment_id = "ocid1.compartment.oc1..test" - config.genai_region = None - - with pytest.raises(OciException) as exc_info: - utils_oci.get_genai_models(config, regional=True) - - assert exc_info.value.status_code == 400 - assert "genai_region" in exc_info.value.detail - - @patch("server.api.utils.oci.init_client") - def test_get_genai_models_returns_models(self, mock_init_client, make_oci_config): - """get_genai_models should return list of GenAI models.""" - mock_model = MagicMock() - mock_model.display_name = "cohere.command-r-plus" - mock_model.capabilities = ["TEXT_GENERATION"] - mock_model.vendor = "cohere" - mock_model.id = "ocid1.model.oc1..test" - mock_model.time_deprecated = None - mock_model.time_dedicated_retired = None - mock_model.time_on_demand_retired = None - - mock_response = MagicMock() - mock_response.data.items = [mock_model] - - mock_client = MagicMock() - mock_client.list_models.return_value = mock_response - mock_init_client.return_value = mock_client - - config = make_oci_config(genai_region="us-chicago-1") - config.genai_compartment_id = "ocid1.compartment.oc1..test" - - result = utils_oci.get_genai_models(config, regional=True) - - assert len(result) == 1 - assert result[0]["model_name"] == "cohere.command-r-plus" - - -class TestGetCompartments: - """Tests for the get_compartments function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_compartments_returns_dict(self, mock_init_client, make_oci_config): - """get_compartments should return dict of compartment paths.""" - mock_compartment = MagicMock() - mock_compartment.id = "ocid1.compartment.oc1..test" - mock_compartment.name = "TestCompartment" - mock_compartment.compartment_id = None # Root level - - mock_client = MagicMock() - mock_client.list_compartments.return_value.data = [mock_compartment] - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.tenancy = "test-tenancy" - - result = utils_oci.get_compartments(config) - - assert "TestCompartment" in result - assert result["TestCompartment"] == "ocid1.compartment.oc1..test" - - -class TestGetBuckets: - """Tests for the get_buckets function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_buckets_returns_list(self, mock_init_client, make_oci_config): - """get_buckets should return list of bucket names.""" - mock_bucket = MagicMock() - mock_bucket.name = "test-bucket" - mock_bucket.freeform_tags = {} - - mock_client = MagicMock() - mock_client.list_buckets.return_value.data = [mock_bucket] - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_buckets("compartment-id", config) - - assert result == ["test-bucket"] - - @patch("server.api.utils.oci.init_client") - def test_get_buckets_excludes_genai_chunk_buckets(self, mock_init_client, make_oci_config): - """get_buckets should exclude buckets with genai_chunk=true tag.""" - mock_bucket1 = MagicMock() - mock_bucket1.name = "normal-bucket" - mock_bucket1.freeform_tags = {} - - mock_bucket2 = MagicMock() - 
mock_bucket2.name = "chunk-bucket" - mock_bucket2.freeform_tags = {"genai_chunk": "true"} - - mock_client = MagicMock() - mock_client.list_buckets.return_value.data = [mock_bucket1, mock_bucket2] - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_buckets("compartment-id", config) - - assert result == ["normal-bucket"] - - @patch("server.api.utils.oci.init_client") - def test_get_buckets_raises_on_service_error(self, mock_init_client, make_oci_config): - """get_buckets should raise OciException on service error.""" - mock_client = MagicMock() - mock_client.list_buckets.side_effect = oci.exceptions.ServiceError( - status=401, code="NotAuthenticated", headers={}, message="Not authenticated" - ) - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - with pytest.raises(OciException) as exc_info: - utils_oci.get_buckets("compartment-id", config) - - assert exc_info.value.status_code == 401 - - -class TestGetBucketObjects: - """Tests for the get_bucket_objects function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_bucket_objects_returns_names(self, mock_init_client, make_oci_config): - """get_bucket_objects should return list of object names.""" - mock_obj = MagicMock() - mock_obj.name = "document.pdf" - - mock_response = MagicMock() - mock_response.data.objects = [mock_obj] - - mock_client = MagicMock() - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_bucket_objects("test-bucket", config) - - assert result == ["document.pdf"] - - @patch("server.api.utils.oci.init_client") - def test_get_bucket_objects_returns_empty_on_not_found(self, mock_init_client, make_oci_config): - """get_bucket_objects should return empty list on service error.""" - mock_client = MagicMock() - mock_client.list_objects.side_effect = oci.exceptions.ServiceError( - status=404, code="BucketNotFound", headers={}, message="Bucket not found" - ) - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_bucket_objects("nonexistent-bucket", config) - - assert result == [] - - -class TestGetBucketObjectsWithMetadata: - """Tests for the get_bucket_objects_with_metadata function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_bucket_objects_with_metadata_returns_supported_files(self, mock_init_client, make_oci_config): - """get_bucket_objects_with_metadata should return only supported file types.""" - mock_pdf = MagicMock() - mock_pdf.name = "document.pdf" - mock_pdf.size = 1000 - mock_pdf.etag = "abc123" - mock_pdf.time_modified = datetime(2024, 1, 1, 12, 0, 0) - mock_pdf.md5 = "md5hash" - - mock_exe = MagicMock() - mock_exe.name = "program.exe" - mock_exe.size = 2000 - mock_exe.etag = "def456" - mock_exe.time_modified = datetime(2024, 1, 1, 12, 0, 0) - mock_exe.md5 = "md5hash2" - - mock_response = MagicMock() - mock_response.data.objects = [mock_pdf, mock_exe] - - mock_client = MagicMock() - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_bucket_objects_with_metadata("test-bucket", config) - - assert len(result) == 1 - assert result[0]["name"] == "document.pdf" - 
assert result[0]["extension"] == "pdf" - - -class TestDetectChangedObjects: - """Tests for the detect_changed_objects function.""" - - def test_detect_new_objects(self): - """detect_changed_objects should identify new objects.""" - current_objects = [{"name": "new_file.pdf", "etag": "abc123", "time_modified": "2024-01-01T12:00:00"}] - processed_objects = {} - - new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) - - assert len(new) == 1 - assert len(modified) == 0 - assert new[0]["name"] == "new_file.pdf" - - def test_detect_modified_objects(self): - """detect_changed_objects should identify modified objects.""" - current_objects = [{"name": "existing.pdf", "etag": "new_etag", "time_modified": "2024-01-02T12:00:00"}] - processed_objects = {"existing.pdf": {"etag": "old_etag", "time_modified": "2024-01-01T12:00:00"}} - - new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) - - assert len(new) == 0 - assert len(modified) == 1 - assert modified[0]["name"] == "existing.pdf" - - def test_detect_unchanged_objects(self): - """detect_changed_objects should not flag unchanged objects.""" - current_objects = [{"name": "existing.pdf", "etag": "same_etag", "time_modified": "2024-01-01T12:00:00"}] - processed_objects = {"existing.pdf": {"etag": "same_etag", "time_modified": "2024-01-01T12:00:00"}} - - new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) - - assert len(new) == 0 - assert len(modified) == 0 - - def test_detect_skips_old_format_metadata(self): - """detect_changed_objects should skip objects with old format metadata.""" - current_objects = [{"name": "old_format.pdf", "etag": "new_etag", "time_modified": "2024-01-02T12:00:00"}] - processed_objects = {"old_format.pdf": {"etag": None, "time_modified": None}} - - new, modified = utils_oci.detect_changed_objects(current_objects, processed_objects) - - assert len(new) == 0 - assert len(modified) == 0 - - -class TestGetObject: - """Tests for the get_object function.""" - - @patch("server.api.utils.oci.init_client") - def test_get_object_downloads_file(self, mock_init_client, make_oci_config, tmp_path): - """get_object should download file to directory.""" - mock_response = MagicMock() - mock_response.data.raw.stream.return_value = [b"file content"] - - mock_client = MagicMock() - mock_client.get_object.return_value = mock_response - mock_init_client.return_value = mock_client - - config = make_oci_config() - config.namespace = "test-namespace" - - result = utils_oci.get_object(str(tmp_path), "folder/document.pdf", "test-bucket", config) - - assert result == str(tmp_path / "document.pdf") - assert (tmp_path / "document.pdf").exists() - assert (tmp_path / "document.pdf").read_bytes() == b"file content" - - -class TestInitGenaiClient: - """Tests for the init_genai_client function.""" - - @patch("server.api.utils.oci.init_client") - def test_init_genai_client_calls_init_client(self, mock_init_client, make_oci_config): - """init_genai_client should call init_client with correct type.""" - mock_client = MagicMock() - mock_init_client.return_value = mock_client - config = make_oci_config() - - result = utils_oci.init_genai_client(config) - - mock_init_client.assert_called_once_with(oci.generative_ai_inference.GenerativeAiInferenceClient, config) - assert result == mock_client - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_oci, "logger") 
- - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_oci.logger.name == "api.utils.oci" diff --git a/test/unit/server/api/utils/test_utils_settings.py b/test/unit/server/api/utils/test_utils_settings.py deleted file mode 100644 index 1f84ec41..00000000 --- a/test/unit/server/api/utils/test_utils_settings.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/settings.py -Tests for settings utility functions. -""" - -# pylint: disable=too-few-public-methods - -import json -from unittest.mock import patch, MagicMock - -import pytest - -from server.api.utils import settings as utils_settings -from server.api.utils.settings import bootstrap - - -class TestCreateClient: - """Tests for the create_client function.""" - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_create_client_success(self, mock_settings, make_settings): - """create_client should create new client from default settings.""" - default_settings = make_settings(client="default") - # Return new iterator each time __iter__ is called (consumed twice: any() and next()) - mock_settings.__iter__ = lambda _: iter([default_settings]) - mock_settings.__bool__ = lambda _: True - mock_settings.append = MagicMock() - - result = utils_settings.create_client("new_client") - - assert result.client == "new_client" - mock_settings.append.assert_called_once() - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_create_client_raises_on_existing(self, mock_settings, make_settings): - """create_client should raise ValueError if client exists.""" - existing_settings = make_settings(client="existing") - mock_settings.__iter__ = lambda _: iter([existing_settings]) - - with pytest.raises(ValueError) as exc_info: - utils_settings.create_client("existing") - - assert "already exists" in str(exc_info.value) - - -class TestGetClient: - """Tests for the get_client function.""" - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_get_client_success(self, mock_settings, make_settings): - """get_client should return client settings.""" - client_settings = make_settings(client="test_client") - mock_settings.__iter__ = lambda _: iter([client_settings]) - - result = utils_settings.get_client("test_client") - - assert result.client == "test_client" - - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_get_client_raises_on_not_found(self, mock_settings): - """get_client should raise ValueError if client not found.""" - mock_settings.__iter__ = lambda _: iter([]) - - with pytest.raises(ValueError) as exc_info: - utils_settings.get_client("nonexistent") - - assert "not found" in str(exc_info.value) - - -class TestUpdateClient: - """Tests for the update_client function.""" - - @patch("server.api.utils.settings.get_client") - @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") - def test_update_client_success(self, mock_settings, mock_get_client, make_settings): - """update_client should update and return client settings.""" - old_settings = make_settings(client="test_client") - new_settings = make_settings(client="other") - - mock_get_client.side_effect = [old_settings, new_settings] - mock_settings.remove = MagicMock() - mock_settings.append = MagicMock() - - utils_settings.update_client(new_settings, "test_client") - - 
mock_settings.remove.assert_called_once_with(old_settings) - mock_settings.append.assert_called_once() - - -class TestGetMcpPromptsWithOverrides: - """Tests for the get_mcp_prompts_with_overrides function.""" - - @pytest.mark.asyncio - @patch("server.api.utils.settings.utils_mcp.list_prompts") - @patch("server.api.utils.settings.defaults") - @patch("server.api.utils.settings.cache.get_override") - async def test_get_mcp_prompts_with_overrides_success(self, mock_get_override, mock_defaults, mock_list_prompts): - """get_mcp_prompts_with_overrides should return list of MCPPrompt.""" - mock_prompt = MagicMock() - mock_prompt.name = "optimizer_test-prompt" - mock_prompt.title = "Test Prompt" - mock_prompt.description = "Test description" - mock_prompt.meta = {"_fastmcp": {"tags": ["rag", "chat"]}} - - mock_list_prompts.return_value = [mock_prompt] - - mock_default_func = MagicMock() - mock_default_func.return_value.content.text = "Default text" - mock_defaults.optimizer_test_prompt = mock_default_func - - mock_get_override.return_value = None - - mock_mcp_engine = MagicMock() - - result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) - - assert len(result) == 1 - assert result[0].name == "optimizer_test-prompt" - assert result[0].text == "Default text" - - @pytest.mark.asyncio - @patch("server.api.utils.settings.utils_mcp.list_prompts") - @patch("server.api.utils.settings.defaults") - @patch("server.api.utils.settings.cache.get_override") - async def test_get_mcp_prompts_uses_override(self, mock_get_override, mock_defaults, mock_list_prompts): - """get_mcp_prompts_with_overrides should use override text when available.""" - mock_prompt = MagicMock() - mock_prompt.name = "optimizer_test-prompt" - mock_prompt.title = None - mock_prompt.description = None - mock_prompt.meta = None - - mock_list_prompts.return_value = [mock_prompt] - - mock_default_func = MagicMock() - mock_default_func.return_value.content.text = "Default text" - mock_defaults.optimizer_test_prompt = mock_default_func - - mock_get_override.return_value = "Override text" - - mock_mcp_engine = MagicMock() - - result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) - - assert result[0].text == "Override text" - - @pytest.mark.asyncio - @patch("server.api.utils.settings.utils_mcp.list_prompts") - async def test_get_mcp_prompts_filters_non_optimizer(self, mock_list_prompts): - """get_mcp_prompts_with_overrides should filter out non-optimizer prompts.""" - mock_prompt1 = MagicMock() - mock_prompt1.name = "optimizer_test" - mock_prompt1.title = None - mock_prompt1.description = None - mock_prompt1.meta = None - - mock_prompt2 = MagicMock() - mock_prompt2.name = "other_prompt" - - mock_list_prompts.return_value = [mock_prompt1, mock_prompt2] - - mock_mcp_engine = MagicMock() - - with patch("server.api.utils.settings.defaults") as mock_defaults: - mock_defaults.optimizer_test = None - with patch("server.api.utils.settings.cache.get_override", return_value=None): - result = await utils_settings.get_mcp_prompts_with_overrides(mock_mcp_engine) - - assert len(result) == 1 - assert result[0].name == "optimizer_test" - - -class TestGetServer: - """Tests for the get_server function.""" - - @pytest.mark.asyncio - @patch("server.api.utils.settings.get_mcp_prompts_with_overrides") - @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) - async def 
test_get_server_returns_config(self, mock_get_prompts): - """get_server should return server configuration dict.""" - mock_get_prompts.return_value = [] - mock_mcp_engine = MagicMock() - - result = await utils_settings.get_server(mock_mcp_engine) - - assert "database_configs" in result - assert "model_configs" in result - assert "oci_configs" in result - assert "prompt_configs" in result - - -class TestUpdateServer: - """Tests for the update_server function.""" - - @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) - def test_update_server_updates_databases(self, make_database, make_settings): - """update_server should update database objects.""" - config_data = { - "client_settings": make_settings().model_dump(), - "database_configs": [make_database(name="NEW_DB").model_dump()], - } - - utils_settings.update_server(config_data) - - assert len(bootstrap.DATABASE_OBJECTS) == 1 - - @patch("server.api.utils.settings._load_prompt_configs") - @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS", []) - @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS", []) - def test_update_server_loads_prompt_configs(self, mock_load_prompts, make_settings): - """update_server should load prompt configs.""" - config_data = { - "client_settings": make_settings().model_dump(), - "prompt_configs": [{"name": "test", "title": "Test Title", "text": "Test text"}], - } - - utils_settings.update_server(config_data) - - mock_load_prompts.assert_called_once_with(config_data) - - -class TestLoadPromptOverride: # pylint: disable=protected-access - """Tests for the _load_prompt_override function.""" - - @patch("server.api.utils.settings.cache.set_override") - def test_load_prompt_override_with_text(self, mock_set_override): - """_load_prompt_override should set cache with text.""" - prompt = {"name": "test_prompt", "text": "Test text"} - - result = utils_settings._load_prompt_override(prompt) - - assert result is True - mock_set_override.assert_called_once_with("test_prompt", "Test text") - - @patch("server.api.utils.settings.cache.set_override") - def test_load_prompt_override_without_text(self, mock_set_override): - """_load_prompt_override should return False without text.""" - prompt = {"name": "test_prompt"} - - result = utils_settings._load_prompt_override(prompt) - - assert result is False - mock_set_override.assert_not_called() - - -class TestLoadPromptConfigs: # pylint: disable=protected-access - """Tests for the _load_prompt_configs function.""" - - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_with_prompts(self, mock_load_override): - """_load_prompt_configs should load all prompts.""" - mock_load_override.return_value = True - config_data = {"prompt_configs": [{"name": "p1", "text": "t1"}, {"name": "p2", "text": "t2"}]} - - utils_settings._load_prompt_configs(config_data) - - assert mock_load_override.call_count == 2 - - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_without_key(self, mock_load_override): - """_load_prompt_configs should handle missing prompt_configs key.""" - config_data = {} - - utils_settings._load_prompt_configs(config_data) - - mock_load_override.assert_not_called() - - @patch("server.api.utils.settings._load_prompt_override") - def test_load_prompt_configs_empty_list(self, 
mock_load_override): - """_load_prompt_configs should handle empty prompt_configs.""" - config_data = {"prompt_configs": []} - - utils_settings._load_prompt_configs(config_data) - - mock_load_override.assert_not_called() - - -class TestLoadConfigFromJsonData: - """Tests for the load_config_from_json_data function.""" - - @patch("server.api.utils.settings.update_server") - @patch("server.api.utils.settings.update_client") - def test_load_config_from_json_data_with_client(self, mock_update_client, mock_update_server, make_settings): - """load_config_from_json_data should update specific client.""" - config_data = {"client_settings": make_settings().model_dump()} - - utils_settings.load_config_from_json_data(config_data, client="test_client") - - mock_update_server.assert_called_once() - mock_update_client.assert_called_once() - - @patch("server.api.utils.settings.update_server") - @patch("server.api.utils.settings.update_client") - def test_load_config_from_json_data_without_client(self, mock_update_client, mock_update_server, make_settings): - """load_config_from_json_data should update server and default when no client.""" - config_data = {"client_settings": make_settings().model_dump()} - - utils_settings.load_config_from_json_data(config_data, client=None) - - mock_update_server.assert_called_once() - assert mock_update_client.call_count == 2 # "server" and "default" - - @patch("server.api.utils.settings.update_server") - def test_load_config_from_json_data_raises_missing_settings(self, _mock_update_server): - """load_config_from_json_data should raise KeyError if missing client_settings.""" - config_data = {} - - with pytest.raises(KeyError) as exc_info: - utils_settings.load_config_from_json_data(config_data) - - assert "client_settings" in str(exc_info.value) - - -class TestReadConfigFromJsonFile: - """Tests for the read_config_from_json_file function.""" - - @patch.dict("os.environ", {"CONFIG_FILE": "/path/to/config.json"}) - @patch("os.path.isfile", return_value=True) - @patch("os.access", return_value=True) - @patch("builtins.open") - def test_read_config_from_json_file_success(self, mock_open, mock_access, mock_isfile, make_settings): - """read_config_from_json_file should return Configuration.""" - _ = (mock_access, mock_isfile) # Used to suppress unused argument warning - - config_data = {"client_settings": make_settings().model_dump()} - mock_open.return_value.__enter__.return_value.read.return_value = json.dumps(config_data) - - # Mock json.load - with patch("json.load", return_value=config_data): - result = utils_settings.read_config_from_json_file() - - assert result is not None - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_settings, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_settings.logger.name == "api.core.settings" diff --git a/test/unit/server/api/utils/test_utils_testbed.py b/test/unit/server/api/utils/test_utils_testbed.py deleted file mode 100644 index f68ab797..00000000 --- a/test/unit/server/api/utils/test_utils_testbed.py +++ /dev/null @@ -1,324 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/utils/testbed.py -Tests for testbed utility functions. 
- -Uses hybrid approach: -- Real Oracle database for testbed table creation and querying -- Mocks for external dependencies (PDF processing, LLM calls) -""" - -# pylint: disable=too-few-public-methods - -import json -from unittest.mock import patch, MagicMock - -import pytest - -from server.api.utils import testbed as utils_testbed - - -class TestJsonlToJsonContent: - """Tests for the jsonl_to_json_content function.""" - - def test_jsonl_to_json_content_single_json(self): - """Should parse single JSON object.""" - content = '{"question": "What is AI?", "answer": "Artificial Intelligence"}' - - result = utils_testbed.jsonl_to_json_content(content) - - parsed = json.loads(result) - assert parsed["question"] == "What is AI?" - - def test_jsonl_to_json_content_jsonl(self): - """Should parse JSONL (multiple lines).""" - content = '{"q": "Q1"}\n{"q": "Q2"}' - - result = utils_testbed.jsonl_to_json_content(content) - - parsed = json.loads(result) - assert len(parsed) == 2 - - def test_jsonl_to_json_content_bytes(self): - """Should handle bytes input.""" - content = b'{"question": "test"}' - - result = utils_testbed.jsonl_to_json_content(content) - - parsed = json.loads(result) - assert parsed["question"] == "test" - - def test_jsonl_to_json_content_single_jsonl(self): - """Should handle single line JSONL.""" - content = '{"question": "test"}\n' - - result = utils_testbed.jsonl_to_json_content(content) - - parsed = json.loads(result) - assert parsed["question"] == "test" - - def test_jsonl_to_json_content_invalid(self): - """Should raise ValueError for invalid content.""" - content = "not valid json at all" - - with pytest.raises(ValueError) as exc_info: - utils_testbed.jsonl_to_json_content(content) - - assert "Invalid JSONL content" in str(exc_info.value) - - -class TestCreateTestsetObjects: - """Tests for the create_testset_objects function. - - Uses mocks since DDL (CREATE TABLE) causes implicit commits in Oracle, - which breaks savepoint-based test isolation. - """ - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_create_testset_objects_executes_ddl(self, mock_execute): - """Should execute SQL to create testset tables.""" - mock_conn = MagicMock() - - utils_testbed.create_testset_objects(mock_conn) - - # Should execute 3 DDL statements (testsets, testset_qa, evaluations) - assert mock_execute.call_count == 3 - - -class TestGetTestsets: - """Tests for the get_testsets function. - - Uses mocks since the function may trigger DDL which causes implicit commits. 
- """ - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_get_testsets_returns_list(self, mock_execute): - """Should return list of TestSets.""" - mock_conn = MagicMock() - # Return empty result set - mock_execute.return_value = [] - - result = utils_testbed.get_testsets(mock_conn) - - assert isinstance(result, list) - assert len(result) == 0 - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_get_testsets_creates_tables_on_first_call(self, mock_execute): - """Should create tables if they don't exist.""" - mock_conn = MagicMock() - # First call returns None (which causes TypeError during unpacking), - # then 3 DDL calls for table creation, then final query returns [] - mock_execute.side_effect = [None, None, None, None, []] - - result = utils_testbed.get_testsets(mock_conn) - - assert isinstance(result, list) - - -class TestGetTestsetQa: - """Tests for the get_testset_qa function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_get_testset_qa_returns_qa(self, mock_execute): - """Should return TestSetQA object.""" - mock_execute.return_value = [('{"question": "Q1"}',)] - mock_conn = MagicMock() - - result = utils_testbed.get_testset_qa(mock_conn, "abc123") - - assert len(result.qa_data) == 1 - - -class TestGetEvaluations: - """Tests for the get_evaluations function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_get_evaluations_returns_list(self, mock_execute): - """Should return list of Evaluation objects.""" - mock_eid = MagicMock() - mock_eid.hex.return_value = "eval123" - mock_execute.return_value = [(mock_eid, "2024-01-01", 0.85)] - mock_conn = MagicMock() - - result = utils_testbed.get_evaluations(mock_conn, "tid123") - - assert len(result) == 1 - assert result[0].correctness == 0.85 - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - @patch("server.api.utils.testbed.create_testset_objects") - def test_get_evaluations_creates_tables_on_error(self, mock_create, mock_execute): - """Should create tables if TypeError occurs.""" - mock_execute.return_value = None - mock_conn = MagicMock() - - result = utils_testbed.get_evaluations(mock_conn, "tid123") - - mock_create.assert_called_once() - assert result == [] - - -class TestDeleteQa: - """Tests for the delete_qa function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_delete_qa_executes_sql(self, mock_execute): - """Should execute DELETE SQL.""" - mock_conn = MagicMock() - - utils_testbed.delete_qa(mock_conn, "tid123") - - mock_execute.assert_called_once() - mock_conn.commit.assert_called_once() - - -class TestUpsertQa: - """Tests for the upsert_qa function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_upsert_qa_single_qa(self, mock_execute): - """Should handle single QA object.""" - mock_execute.return_value = "tid123" - mock_conn = MagicMock() - json_data = '{"question": "Q1", "answer": "A1"}' - - result = utils_testbed.upsert_qa(mock_conn, "TestSet", "2024-01-01T00:00:00.000", json_data) - - mock_execute.assert_called_once() - assert result == "tid123" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_upsert_qa_multiple_qa(self, mock_execute): - """Should handle multiple QA objects.""" - mock_execute.return_value = "tid123" - mock_conn = MagicMock() - json_data = '[{"q": "Q1"}, {"q": "Q2"}]' - - utils_testbed.upsert_qa(mock_conn, "TestSet", "2024-01-01T00:00:00.000", json_data) - - 
mock_execute.assert_called_once() - - -class TestInsertEvaluation: - """Tests for the insert_evaluation function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - def test_insert_evaluation_executes_sql(self, mock_execute): - """Should execute INSERT SQL.""" - mock_execute.return_value = "eid123" - mock_conn = MagicMock() - - result = utils_testbed.insert_evaluation( - mock_conn, "tid123", "2024-01-01T00:00:00.000", 0.85, '{"model": "gpt-4"}', b"report_data" - ) - - mock_execute.assert_called_once() - assert result == "eid123" - - -class TestLoadAndSplit: - """Tests for the load_and_split function.""" - - @patch("server.api.utils.testbed.PdfReader") - @patch("server.api.utils.testbed.SentenceSplitter") - def test_load_and_split_processes_pdf(self, mock_splitter, mock_reader): - """Should load PDF and split into nodes.""" - mock_page = MagicMock() - mock_page.extract_text.return_value = "Page content" - mock_reader.return_value.pages = [mock_page] - - mock_splitter_instance = MagicMock() - mock_splitter_instance.return_value = ["node1", "node2"] - mock_splitter.return_value = mock_splitter_instance - - utils_testbed.load_and_split("/path/to/doc.pdf", chunk_size=1024) - - mock_reader.assert_called_once_with("/path/to/doc.pdf") - mock_splitter.assert_called_once_with(chunk_size=1024) - - -class TestBuildKnowledgeBase: - """Tests for the build_knowledge_base function.""" - - @patch("server.api.utils.testbed.utils_models.get_litellm_config") - @patch("server.api.utils.testbed.set_llm_model") - @patch("server.api.utils.testbed.set_embedding_model") - @patch("server.api.utils.testbed.KnowledgeBase") - @patch("server.api.utils.testbed.generate_testset") - def test_build_knowledge_base_success( - self, mock_generate, mock_kb, mock_set_embed, mock_set_llm, mock_get_config, make_oci_config - ): - """Should create knowledge base and generate testset.""" - mock_get_config.return_value = {"api_key": "test"} - mock_testset = MagicMock() - mock_generate.return_value = mock_testset - - mock_text_node = MagicMock() - mock_text_node.text = "Sample text" - text_nodes = [mock_text_node] - - oci_config = make_oci_config() - - result = utils_testbed.build_knowledge_base( - text_nodes, - questions=5, - ll_model="openai/gpt-4", - embed_model="openai/text-embedding-3-small", - oci_config=oci_config, - ) - - mock_set_llm.assert_called_once() - mock_set_embed.assert_called_once() - mock_kb.assert_called_once() - mock_generate.assert_called_once() - assert result == mock_testset - - -class TestProcessReport: - """Tests for the process_report function.""" - - @patch("server.api.utils.testbed.utils_databases.execute_sql") - @patch("server.api.utils.testbed.pickle.loads") - def test_process_report_success(self, mock_pickle, mock_execute, make_settings): - """Should process evaluation report.""" - mock_eid = MagicMock() - mock_eid.hex.return_value = "eid123" - - mock_report = MagicMock() - mock_report.to_pandas.return_value = MagicMock(to_dict=MagicMock(return_value={})) - mock_report.correctness_by_topic.return_value = MagicMock(to_dict=MagicMock(return_value={})) - mock_report.failures = MagicMock(to_dict=MagicMock(return_value={})) - mock_pickle.return_value = mock_report - - # Settings needs to be a valid Settings object (or dict with required fields) - settings_data = make_settings().model_dump() - mock_execute.return_value = [ - { - "EID": mock_eid, - "EVALUATED": "2024-01-01", - "CORRECTNESS": 0.85, - "SETTINGS": settings_data, - "RAG_REPORT": b"data", - } - ] - mock_conn = MagicMock() - - 
result = utils_testbed.process_report(mock_conn, "eid123")
-
-        assert result.eid == "eid123"
-        assert result.correctness == 0.85
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(utils_testbed, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert utils_testbed.logger.name == "api.utils.testbed"
diff --git a/test/unit/server/api/utils/test_utils_testbed_metrics.py b/test/unit/server/api/utils/test_utils_testbed_metrics.py
deleted file mode 100644
index 4431f4e5..00000000
--- a/test/unit/server/api/utils/test_utils_testbed_metrics.py
+++ /dev/null
@@ -1,345 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Unit tests for server/api/utils/testbed_metrics.py
-Tests for custom testbed evaluation metrics.
-"""
-
-# pylint: disable=too-few-public-methods,protected-access
-
-from unittest.mock import patch, MagicMock
-
-import pytest
-
-from giskard.llm.errors import LLMGenerationError
-
-from server.api.utils import testbed_metrics
-
-
-class TestFormatConversation:
-    """Tests for the format_conversation function."""
-
-    def test_format_conversation_single_message(self):
-        """Should format single message correctly."""
-        conversation = [{"role": "user", "content": "Hello"}]
-
-        result = testbed_metrics.format_conversation(conversation)
-
-        assert result == "<user>Hello</user>"
-
-    def test_format_conversation_multiple_messages(self):
-        """Should format multiple messages with double newlines."""
-        conversation = [
-            {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there"},
-        ]
-
-        result = testbed_metrics.format_conversation(conversation)
-
-        assert "Hello" in result
-        assert "Hi there" in result
-        assert "\n\n" in result
-
-    def test_format_conversation_lowercases_role(self):
-        """Should lowercase role names in tags."""
-        conversation = [{"role": "USER", "content": "Test"}]
-
-        result = testbed_metrics.format_conversation(conversation)
-
-        assert result == "<user>Test</user>"
-
-    def test_format_conversation_empty_list(self):
-        """Should return empty string for empty conversation."""
-        result = testbed_metrics.format_conversation([])
-
-        assert result == ""
-
-    def test_format_conversation_preserves_content(self):
-        """Should preserve message content including special characters."""
-        conversation = [{"role": "user", "content": "What is 2 + 2?\nIs it 4?"}]
-
-        result = testbed_metrics.format_conversation(conversation)
-
-        assert "What is 2 + 2?\nIs it 4?" 
in result - - -class TestCorrectnessInputTemplate: - """Tests for the CORRECTNESS_INPUT_TEMPLATE constant.""" - - def test_template_contains_placeholders(self): - """Template should contain all required placeholders.""" - template = testbed_metrics.CORRECTNESS_INPUT_TEMPLATE - - assert "{description}" in template - assert "{conversation}" in template - assert "{answer}" in template - assert "{reference_answer}" in template - - def test_template_format_works(self): - """Template should be formattable with all placeholders.""" - result = testbed_metrics.CORRECTNESS_INPUT_TEMPLATE.format( - description="Test agent", - conversation="Hello", - answer="Hi there", - reference_answer="Hello back", - ) - - assert "Test agent" in result - assert "Hello" in result - assert "Hi there" in result - assert "Hello back" in result - - -class TestCustomCorrectnessMetricInit: - """Tests for CustomCorrectnessMetric initialization.""" - - def test_init_with_required_params(self): - """Should initialize with required parameters.""" - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - assert metric.system_prompt == "You are a judge." - assert metric.agent_description == "A chatbot answering questions." - - def test_init_with_custom_agent_description(self): - """Should accept custom agent description.""" - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - agent_description="A specialized Q&A bot.", - ) - - assert metric.agent_description == "A specialized Q&A bot." - - def test_init_with_llm_client(self): - """Should accept custom LLM client.""" - mock_client = MagicMock() - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - llm_client=mock_client, - ) - - assert metric._llm_client == mock_client - - -class TestCustomCorrectnessMetricCall: - """Tests for CustomCorrectnessMetric __call__ method.""" - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_returns_correctness_result(self, mock_parse, mock_get_client): - """Should return correctness evaluation result.""" - mock_client = MagicMock() - mock_client.complete.return_value = MagicMock(content='{"correctness": true}') - mock_get_client.return_value = mock_client - mock_parse.return_value = {"correctness": True} - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "What is AI?" 
- mock_sample.reference_answer = "Artificial Intelligence" - - mock_answer = MagicMock() - mock_answer.message = "AI stands for Artificial Intelligence" - - result = metric(mock_sample, mock_answer) - - assert result == {"correctness": True} - mock_client.complete.assert_called_once() - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_strips_reason_when_correct(self, mock_parse, mock_get_client): - """Should strip correctness_reason when answer is correct.""" - mock_client = MagicMock() - mock_client.complete.return_value = MagicMock(content='{}') - mock_get_client.return_value = mock_client - mock_parse.return_value = {"correctness": True, "correctness_reason": "Matches exactly"} - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "A" - - result = metric(mock_sample, mock_answer) - - assert "correctness_reason" not in result - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_keeps_reason_when_incorrect(self, mock_parse, mock_get_client): - """Should keep correctness_reason when answer is incorrect.""" - mock_client = MagicMock() - mock_client.complete.return_value = MagicMock(content='{}') - mock_get_client.return_value = mock_client - mock_parse.return_value = {"correctness": False, "correctness_reason": "Does not match"} - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "Wrong" - - result = metric(mock_sample, mock_answer) - - assert result["correctness_reason"] == "Does not match" - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_raises_on_non_boolean_correctness(self, mock_parse, mock_get_client): - """Should raise LLMGenerationError if correctness is not boolean.""" - mock_client = MagicMock() - mock_client.complete.return_value = MagicMock(content='{}') - mock_get_client.return_value = mock_client - mock_parse.return_value = {"correctness": "yes"} # String instead of bool - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "A" - - with pytest.raises(LLMGenerationError) as exc_info: - metric(mock_sample, mock_answer) - - assert "Expected boolean" in str(exc_info.value) - - @patch("server.api.utils.testbed_metrics.get_default_client") - def test_call_reraises_llm_generation_error(self, mock_get_client): - """Should re-raise LLMGenerationError from LLM client.""" - mock_client = MagicMock() - mock_client.complete.side_effect = LLMGenerationError("LLM failed") - mock_get_client.return_value = mock_client - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - 
mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "A" - - with pytest.raises(LLMGenerationError): - metric(mock_sample, mock_answer) - - @patch("server.api.utils.testbed_metrics.get_default_client") - def test_call_wraps_other_exceptions(self, mock_get_client): - """Should wrap other exceptions in LLMGenerationError.""" - mock_client = MagicMock() - mock_client.complete.side_effect = RuntimeError("Unexpected error") - mock_get_client.return_value = mock_client - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "A" - - with pytest.raises(LLMGenerationError) as exc_info: - metric(mock_sample, mock_answer) - - assert "Error while evaluating" in str(exc_info.value) - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_uses_provided_llm_client(self, mock_parse, mock_get_client): - """Should use provided LLM client instead of default.""" - mock_provided_client = MagicMock() - mock_provided_client.complete.return_value = MagicMock(content='{}') - mock_parse.return_value = {"correctness": True} - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - llm_client=mock_provided_client, - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [] - mock_sample.question = "Q" - mock_sample.reference_answer = "A" - - mock_answer = MagicMock() - mock_answer.message = "A" - - metric(mock_sample, mock_answer) - - mock_provided_client.complete.assert_called_once() - mock_get_client.assert_not_called() - - @patch("server.api.utils.testbed_metrics.get_default_client") - @patch("server.api.utils.testbed_metrics.parse_json_output") - def test_call_includes_conversation_history(self, mock_parse, mock_get_client): - """Should include conversation history in the prompt.""" - mock_client = MagicMock() - mock_client.complete.return_value = MagicMock(content='{}') - mock_get_client.return_value = mock_client - mock_parse.return_value = {"correctness": True} - - metric = testbed_metrics.CustomCorrectnessMetric( - name="correctness", - system_prompt="You are a judge.", - ) - - mock_sample = MagicMock() - mock_sample.conversation_history = [ - {"role": "user", "content": "Previous question"}, - {"role": "assistant", "content": "Previous answer"}, - ] - mock_sample.question = "Follow-up question" - mock_sample.reference_answer = "Expected answer" - - mock_answer = MagicMock() - mock_answer.message = "Actual answer" - - metric(mock_sample, mock_answer) - - call_args = mock_client.complete.call_args - user_message = call_args.kwargs["messages"][1].content - assert "Previous question" in user_message - assert "Previous answer" in user_message - assert "Follow-up question" in user_message diff --git a/test/unit/server/api/utils/test_utils_webscrape.py b/test/unit/server/api/utils/test_utils_webscrape.py deleted file mode 100644 index cd042992..00000000 --- a/test/unit/server/api/utils/test_utils_webscrape.py +++ /dev/null @@ -1,419 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-
-Unit tests for server/api/utils/webscrape.py
-Tests for web scraping and content extraction utilities.
-"""
-
-# pylint: disable=too-few-public-methods
-
-from test.unit.server.api.conftest import create_mock_aiohttp_session
-from unittest.mock import patch, AsyncMock
-
-import pytest
-from bs4 import BeautifulSoup
-
-from server.api.utils import webscrape
-
-
-class TestNormalizeWs:
-    """Tests for the normalize_ws function."""
-
-    def test_normalize_ws_removes_extra_spaces(self):
-        """normalize_ws should collapse multiple spaces into one."""
-        result = webscrape.normalize_ws("Hello    world")
-        assert result == "Hello world"
-
-    def test_normalize_ws_removes_newlines(self):
-        """normalize_ws should replace newlines with spaces."""
-        result = webscrape.normalize_ws("Hello\n\nworld")
-        assert result == "Hello world"
-
-    def test_normalize_ws_strips_whitespace(self):
-        """normalize_ws should strip leading/trailing whitespace."""
-        result = webscrape.normalize_ws("  Hello world  ")
-        assert result == "Hello world"
-
-    def test_normalize_ws_handles_tabs(self):
-        """normalize_ws should handle tab characters."""
-        result = webscrape.normalize_ws("Hello\t\tworld")
-        assert result == "Hello world"
-
-    def test_normalize_ws_normalizes_unicode(self):
-        """normalize_ws should normalize unicode characters."""
-        # NFKC normalization should convert full-width to half-width
-        result = webscrape.normalize_ws("Ｈｅｌｌｏ")  # Full-width characters
-        assert result == "Hello"
-
-    def test_normalize_ws_empty_string(self):
-        """normalize_ws should handle empty string."""
-        result = webscrape.normalize_ws("")
-        assert result == ""
-
-
-class TestCleanSoup:
-    """Tests for the clean_soup function."""
-
-    def test_clean_soup_removes_script_tags(self):
-        """clean_soup should remove script tags."""
-        html = "<html><body><script>var x = 1;</script><p>Content</p></body></html>"
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find("script") is None
-        assert soup.find("p") is not None
-
-    def test_clean_soup_removes_style_tags(self):
-        """clean_soup should remove style tags."""
-        html = "<html><body><style>p {color: red;}</style><p>Content</p></body></html>"
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find("style") is None
-
-    def test_clean_soup_removes_noscript_tags(self):
-        """clean_soup should remove noscript tags."""
-        html = "<html><body><noscript>Enable JS</noscript><p>Content</p></body></html>"
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find("noscript") is None
-
-    def test_clean_soup_removes_nav_elements(self):
-        """clean_soup should remove navigation elements."""
-        html = '<html><body><nav>Menu</nav><p>Content</p></body></html>'
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find("nav") is None
-
-    def test_clean_soup_removes_elements_by_class(self):
-        """clean_soup should remove elements with bad class names."""
-        html = '<html><body><div class="footer">Footer</div><p>Content</p></body></html>'
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find(class_="footer") is None
-
-    def test_clean_soup_preserves_content(self):
-        """clean_soup should preserve main content."""
-        html = "<html><body><p>Important content</p></body></html>"
-        soup = BeautifulSoup(html, "html.parser")
-
-        webscrape.clean_soup(soup)
-
-        assert soup.find("p") is not None
-        assert "Important content" in soup.get_text()
-
-
-class TestHeadingLevel:
-    """Tests for the heading_level function."""
-
-    def test_heading_level_h1(self):
-        """heading_level should return 1 for h1."""
-        soup = BeautifulSoup("<html><body><h1>Title</h1></body></html>", "html.parser")
-        tag = soup.find("h1")
-
-        result = webscrape.heading_level(tag)
-
-        assert result == 1
-
-    def test_heading_level_h2(self):
-        """heading_level should return 2 for h2."""
-        soup = BeautifulSoup("<html><body><h2>Title</h2></body></html>", "html.parser")
-        tag = soup.find("h2")
-
-        result = webscrape.heading_level(tag)
-
-        assert result == 2
-
-    def test_heading_level_h6(self):
-        """heading_level should return 6 for h6."""
-        soup = BeautifulSoup("<h6>Title</h6>", "html.parser")
-        tag = soup.find("h6")
-
-        result = webscrape.heading_level(tag)
-
-        assert result == 6
-
-
-class TestGroupBySections:
-    """Tests for the group_by_sections function."""
-
-    def test_group_by_sections_extracts_sections(self):
-        """group_by_sections should extract section content."""
-        html = """
-        <html><body>
-        <section>
-        <h2>Section Title</h2>
-        <p>Paragraph 1</p>
-        <p>Paragraph 2</p>
-        </section>
-        </body></html>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_sections(soup)
-
-        assert len(result) == 1
-        assert result[0]["title"] == "Section Title"
-        assert "Paragraph 1" in result[0]["content"]
-
-    def test_group_by_sections_handles_articles(self):
-        """group_by_sections should handle article tags."""
-        html = """
-        <html><body>
-        <article>
-        <h2>Article Title</h2>
-        <p>Article content</p>
-        </article>
-        </body></html>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_sections(soup)
-
-        assert len(result) == 1
-        assert result[0]["title"] == "Article Title"
-
-    def test_group_by_sections_no_sections(self):
-        """group_by_sections should return empty list when no sections."""
-        html = "<html><body><p>Plain content</p></body></html>"
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_sections(soup)
-
-        assert not result
-
-
-class TestTableToMarkdown:
-    """Tests for the table_to_markdown function."""
-
-    def test_table_to_markdown_basic_table(self):
-        """table_to_markdown should convert table to markdown."""
-        html = """
-        <table>
-        <tr><th>Header 1</th><th>Header 2</th></tr>
-        <tr><td>Cell 1</td><td>Cell 2</td></tr>
-        </table>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-        table = soup.find("table")
-
-        result = webscrape.table_to_markdown(table)
-
-        assert "| Header 1 | Header 2 |" in result
-        assert "| --- | --- |" in result
-        assert "| Cell 1 | Cell 2 |" in result
-
-    def test_table_to_markdown_empty_table(self):
-        """table_to_markdown should handle empty table."""
-        html = "<table></table>"
-        soup = BeautifulSoup(html, "html.parser")
-        table = soup.find("table")
-
-        result = webscrape.table_to_markdown(table)
-
-        assert result == ""
-
-
-class TestGroupByHeadings:
-    """Tests for the group_by_headings function."""
-
-    def test_group_by_headings_extracts_sections(self):
-        """group_by_headings should group content by heading."""
-        html = """
-        <html><body>
-        <h1>Section 1</h1>
-        <p>Content 1</p>
-        <h1>Section 2</h1>
-        <p>Content 2</p>
-        </body></html>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_headings(soup)
-
-        assert len(result) == 2
-        assert result[0]["title"] == "Section 1"
-        assert result[1]["title"] == "Section 2"
-
-    def test_group_by_headings_handles_lists(self):
-        """group_by_headings should include list items."""
-        html = """
-        <html><body>
-        <h2>List Section</h2>
-        <ul>
-        <li>Item 1</li>
-        <li>Item 2</li>
-        </ul>
-        </body></html>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_headings(soup)
-
-        assert len(result) == 1
-        assert "Item 1" in result[0]["content"]
-
-    def test_group_by_headings_respects_hierarchy(self):
-        """group_by_headings should stop at same or higher level heading."""
-        html = """
-        <html><body>
-        <h2>Parent</h2>
-        <p>Parent content</p>
-        <h3>Child</h3>
-        <p>Child content</p>
-        <h2>Sibling</h2>
-        <p>Sibling content</p>
-        </body></html>
-        """
-        soup = BeautifulSoup(html, "html.parser")
-
-        result = webscrape.group_by_headings(soup)
-
-        # h2 sections should not include content from sibling h2
-        parent_section = next(s for s in result if s["title"] == "Parent")
-        assert "Sibling content" not in parent_section["content"]
-
-
-class TestSectionsToMarkdown:
-    """Tests for the sections_to_markdown function."""
-
-    def test_sections_to_markdown_basic(self):
-        """sections_to_markdown should convert sections to markdown."""
-        sections = [
-            {"title": "Section 1", "level": 1, "paragraphs": ["Para 1"]},
-            {"title": "Section 2", "level": 2, "paragraphs": ["Para 2"]},
-        ]
-
-        result = webscrape.sections_to_markdown(sections)
-
-        assert "# Section 1" in result
-        assert "## Section 2" in result
-
-    def test_sections_to_markdown_empty_list(self):
-        """sections_to_markdown should handle empty list."""
-        result = webscrape.sections_to_markdown([])
-
-        assert result == ""
-
-
-class TestSlugify:
-    """Tests for the slugify function."""
-
-    def test_slugify_basic(self):
-        """slugify should convert text to URL-safe slug."""
-        result = webscrape.slugify("Hello World")
-
-        assert result == "hello-world"
-
-    def test_slugify_special_characters(self):
-        """slugify should remove special characters."""
-        result = webscrape.slugify("Hello! World?")
-
-        assert result == "hello-world"
-
-    def test_slugify_max_length(self):
-        """slugify should respect max length."""
-        long_text = "a" * 100
-        result = webscrape.slugify(long_text, max_len=10)
-
-        assert len(result) == 10
-
-    def test_slugify_empty_string(self):
-        """slugify should return 'page' for empty result."""
-        result = webscrape.slugify("!!!")
-
-        assert result == "page"
-
-    def test_slugify_multiple_spaces(self):
-        """slugify should collapse multiple spaces/dashes."""
-        result = webscrape.slugify("Hello    World")
-
-        assert result == "hello-world"
-
-
-class TestFetchAndExtractParagraphs:
-    """Tests for the fetch_and_extract_paragraphs function."""
-
-    @pytest.mark.asyncio
-    @patch("aiohttp.ClientSession")
-    async def test_fetch_and_extract_paragraphs_success(self, mock_session_class):
-        """fetch_and_extract_paragraphs should extract paragraphs from URL."""
-        html = "<html><body><p>Paragraph 1</p><p>Paragraph 2</p></body></html>"
-
-        mock_response = AsyncMock()
-        mock_response.text = AsyncMock(return_value=html)
-        create_mock_aiohttp_session(mock_session_class, mock_response)
-
-        result = await webscrape.fetch_and_extract_paragraphs("https://example.com")
-
-        assert len(result) == 2
-        assert "Paragraph 1" in result
-        assert "Paragraph 2" in result
-
-
-class TestFetchAndExtractSections:
-    """Tests for the fetch_and_extract_sections function."""
-
-    @pytest.mark.asyncio
-    @patch("aiohttp.ClientSession")
-    async def test_fetch_and_extract_sections_with_sections(self, mock_session_class):
-        """fetch_and_extract_sections should extract sections from URL."""
-        html = """
-        <html><body>
-        <section><h2>Title</h2><p>Content</p></section>
-        </body></html>
-        """
-
-        mock_response = AsyncMock()
-        mock_response.text = AsyncMock(return_value=html)
-        create_mock_aiohttp_session(mock_session_class, mock_response)
-
-        result = await webscrape.fetch_and_extract_sections("https://example.com")
-
-        assert len(result) == 1
-        assert result[0]["title"] == "Title"
-
-    @pytest.mark.asyncio
-    @patch("aiohttp.ClientSession")
-    async def test_fetch_and_extract_sections_falls_back_to_headings(self, mock_session_class):
-        """fetch_and_extract_sections should fall back to headings."""
-        html = """
-        <html><body>
-        <h2>Heading</h2>
-        <p>Content</p>
-        </body></html>
- - """ - - mock_response = AsyncMock() - mock_response.text = AsyncMock(return_value=html) - create_mock_aiohttp_session(mock_session_class, mock_response) - - result = await webscrape.fetch_and_extract_sections("https://example.com") - - assert len(result) == 1 - assert result[0]["title"] == "Heading" - - -class TestBadChunks: - """Tests for the BAD_CHUNKS constant.""" - - def test_bad_chunks_contains_common_elements(self): - """BAD_CHUNKS should contain common unwanted elements.""" - assert "nav" in webscrape.BAD_CHUNKS - assert "header" in webscrape.BAD_CHUNKS - assert "footer" in webscrape.BAD_CHUNKS - assert "ads" in webscrape.BAD_CHUNKS - assert "comment" in webscrape.BAD_CHUNKS - - def test_bad_chunks_is_list(self): - """BAD_CHUNKS should be a list.""" - assert isinstance(webscrape.BAD_CHUNKS, list) diff --git a/test/unit/server/api/v1/__init__.py b/test/unit/server/api/v1/__init__.py deleted file mode 100644 index a6ad55f3..00000000 --- a/test/unit/server/api/v1/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# v1 API unit test package diff --git a/test/unit/server/api/v1/test_v1_chat.py b/test/unit/server/api/v1/test_v1_chat.py deleted file mode 100644 index 9c0b9b75..00000000 --- a/test/unit/server/api/v1/test_v1_chat.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/chat.py -Tests for chat completion endpoints. -""" - -from unittest.mock import patch, MagicMock -import pytest -from fastapi.responses import StreamingResponse - -from server.api.v1 import chat - - -class TestChatPost: - """Tests for the chat_post endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chat.completion_generator") - async def test_chat_post_returns_last_message(self, mock_generator, make_chat_request): - """chat_post should return the final completion message.""" - request = make_chat_request(content="Hello") - mock_response = {"choices": [{"message": {"content": "Hi there!"}}]} - - async def mock_gen(): - yield mock_response - - mock_generator.return_value = mock_gen() - - result = await chat.chat_post(request=request, client="test_client") - - assert result == mock_response - mock_generator.assert_called_once_with("test_client", request, "completions") - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chat.completion_generator") - async def test_chat_post_iterates_through_all_chunks(self, mock_generator, make_chat_request): - """chat_post should iterate through all chunks and return last.""" - request = make_chat_request(content="Hello") - - async def mock_gen(): - yield "chunk1" - yield "chunk2" - yield {"final": "response"} - - mock_generator.return_value = mock_gen() - - result = await chat.chat_post(request=request, client="test_client") - - assert result == {"final": "response"} - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chat.completion_generator") - async def test_chat_post_uses_default_client(self, mock_generator, make_chat_request): - """chat_post should use 'server' as default client.""" - request = make_chat_request() - - async def mock_gen(): - yield {"response": "data"} - - mock_generator.return_value = mock_gen() - - await chat.chat_post(request=request, client="server") - - mock_generator.assert_called_once_with("server", request, "completions") - - -class TestChatStream: - """Tests for the chat_stream endpoint.""" - - @pytest.mark.asyncio - 
@patch("server.api.v1.chat.chat.completion_generator") - async def test_chat_stream_returns_streaming_response(self, mock_generator, make_chat_request): - """chat_stream should return a StreamingResponse.""" - request = make_chat_request(content="Hello") - - async def mock_gen(): - yield b"chunk1" - yield b"chunk2" - - mock_generator.return_value = mock_gen() - - result = await chat.chat_stream(request=request, client="test_client") - - assert isinstance(result, StreamingResponse) - assert result.media_type == "application/octet-stream" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chat.completion_generator") - async def test_chat_stream_calls_generator_with_streams_mode(self, mock_generator, make_chat_request): - """chat_stream should call generator with 'streams' mode.""" - request = make_chat_request() - - async def mock_gen(): - yield b"data" - - mock_generator.return_value = mock_gen() - - await chat.chat_stream(request=request, client="test_client") - - mock_generator.assert_called_once_with("test_client", request, "streams") - - -class TestChatHistoryClean: - """Tests for the chat_history_clean endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_clean_success(self, mock_graph): - """chat_history_clean should clear history and return confirmation.""" - mock_graph.update_state = MagicMock(return_value=None) - - result = await chat.chat_history_clean(client="test_client") - - assert len(result) == 1 - assert "forgotten" in result[0].content - assert result[0].role == "system" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_clean_updates_state_correctly(self, mock_graph): - """chat_history_clean should update state with correct values.""" - mock_graph.update_state = MagicMock(return_value=None) - - await chat.chat_history_clean(client="test_client") - - call_args = mock_graph.update_state.call_args - values = call_args[1]["values"] - - assert "messages" in values - assert values["cleaned_messages"] == [] - assert values["context_input"] == "" - assert values["documents"] == {} - assert values["final_response"] == {} - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_clean_handles_key_error(self, mock_graph): - """chat_history_clean should handle KeyError gracefully.""" - mock_graph.update_state = MagicMock(side_effect=KeyError("thread not found")) - - result = await chat.chat_history_clean(client="nonexistent_client") - - assert len(result) == 1 - assert "no history" in result[0].content - assert result[0].role == "system" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_clean_uses_correct_thread_id(self, mock_graph): - """chat_history_clean should use client as thread_id.""" - mock_graph.update_state = MagicMock(return_value=None) - - await chat.chat_history_clean(client="my_client_id") - - call_args = mock_graph.update_state.call_args - # config is passed as keyword argument, RunnableConfig is dict-like - config = call_args.kwargs["config"] - - assert config["configurable"]["thread_id"] == "my_client_id" - - -class TestChatHistoryReturn: - """Tests for the chat_history_return endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - @patch("server.api.v1.chat.convert_to_openai_messages") - async def test_chat_history_return_success(self, mock_convert, mock_graph): - """chat_history_return should return 
chat messages.""" - mock_messages = [ - MagicMock(content="Hello", role="user"), - MagicMock(content="Hi there", role="assistant"), - ] - mock_state = MagicMock() - mock_state.values = {"messages": mock_messages} - mock_graph.get_state = MagicMock(return_value=mock_state) - mock_convert.return_value = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"}, - ] - - result = await chat.chat_history_return(client="test_client") - - assert len(result) == 2 - mock_convert.assert_called_once_with(mock_messages) - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_return_handles_key_error(self, mock_graph): - """chat_history_return should handle KeyError gracefully.""" - mock_graph.get_state = MagicMock(side_effect=KeyError("thread not found")) - - result = await chat.chat_history_return(client="nonexistent_client") - - assert len(result) == 1 - assert "no history" in result[0].content - assert result[0].role == "system" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - async def test_chat_history_return_uses_correct_thread_id(self, mock_graph): - """chat_history_return should use client as thread_id.""" - mock_state = MagicMock() - mock_state.values = {"messages": []} - mock_graph.get_state = MagicMock(return_value=mock_state) - - with patch("server.api.v1.chat.convert_to_openai_messages", return_value=[]): - await chat.chat_history_return(client="my_client_id") - - call_args = mock_graph.get_state.call_args - # config is passed as keyword argument, RunnableConfig is dict-like - config = call_args.kwargs["config"] - - assert config["configurable"]["thread_id"] == "my_client_id" - - @pytest.mark.asyncio - @patch("server.api.v1.chat.chatbot.chatbot_graph") - @patch("server.api.v1.chat.convert_to_openai_messages") - async def test_chat_history_return_empty_history(self, mock_convert, mock_graph): - """chat_history_return should handle empty history.""" - mock_state = MagicMock() - mock_state.values = {"messages": []} - mock_graph.get_state = MagicMock(return_value=mock_state) - mock_convert.return_value = [] - - result = await chat.chat_history_return(client="test_client") - - assert result == [] - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(chat, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in chat.auth.routes] - - assert "/completions" in routes - assert "/streams" in routes - assert "/history" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(chat, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert chat.logger.name == "endpoints.v1.chat" diff --git a/test/unit/server/api/v1/test_v1_databases.py b/test/unit/server/api/v1/test_v1_databases.py deleted file mode 100644 index 98f957a4..00000000 --- a/test/unit/server/api/v1/test_v1_databases.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/databases.py -Tests for database configuration endpoints. 
- -Note: These tests mock utils_databases functions to test endpoint logic -(HTTP responses, error handling). The underlying database operations -are tested with real Oracle database in test_utils_databases.py. -""" - -from unittest.mock import patch, MagicMock -import pytest -from fastapi import HTTPException - -from server.api.v1 import databases - - -class TestDatabasesList: - """Tests for the databases_list endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_list_returns_all_databases(self, mock_get_databases, make_database): - """databases_list should return all configured databases.""" - db_list = [ - make_database(name="DB1"), - make_database(name="DB2"), - ] - mock_get_databases.return_value = db_list - - result = await databases.databases_list() - - assert result == db_list - mock_get_databases.assert_called_once_with(validate=False) - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_list_returns_empty_list(self, mock_get_databases): - """databases_list should return empty list when no databases.""" - mock_get_databases.return_value = [] - - result = await databases.databases_list() - - assert result == [] - mock_get_databases.assert_called_once_with(validate=False) - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_list_raises_404_on_value_error(self, mock_get_databases): - """databases_list should raise 404 when ValueError occurs.""" - mock_get_databases.side_effect = ValueError("No databases found") - - with pytest.raises(HTTPException) as exc_info: - await databases.databases_list() - - assert exc_info.value.status_code == 404 - mock_get_databases.assert_called_once_with(validate=False) - - -class TestDatabasesGet: - """Tests for the databases_get endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_get_returns_single_database(self, mock_get_databases, make_database): - """databases_get should return a single database by name.""" - database = make_database(name="TEST_DB") - mock_get_databases.return_value = database - - result = await databases.databases_get(name="TEST_DB") - - assert result == database - mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=True) - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_get_raises_404_when_not_found(self, mock_get_databases): - """databases_get should raise 404 when database not found.""" - mock_get_databases.side_effect = ValueError("Database not found") - - with pytest.raises(HTTPException) as exc_info: - await databases.databases_get(name="NONEXISTENT") - - assert exc_info.value.status_code == 404 - mock_get_databases.assert_called_once_with(db_name="NONEXISTENT", validate=True) - - -class TestDatabasesUpdate: - """Tests for the databases_update endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - @patch("server.api.v1.databases.utils_databases.connect") - @patch("server.api.v1.databases.utils_databases.disconnect") - async def test_databases_update_returns_updated_database( - self, mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth - ): - """databases_update should return the updated database.""" - existing_db = make_database(name="TEST_DB", user="old_user") - # First call 
returns the single db, second call returns list for cleanup - mock_get_databases.side_effect = [existing_db, [existing_db]] - mock_connect.return_value = MagicMock() - - payload = make_database_auth(user="new_user", password="new_pass", dsn="localhost:1521/TEST") - - result = await databases.databases_update(name="TEST_DB", payload=payload) - - assert result.user == "new_user" - assert result.connected is True - - # Verify get_databases called twice: first to get target DB, second to get all DBs for cleanup - assert mock_get_databases.call_count == 2 - mock_get_databases.assert_any_call(db_name="TEST_DB", validate=False) - mock_get_databases.assert_any_call() - - # Verify connect was called with the payload (which has config_dir/wallet_location set from db) - mock_connect.assert_called_once() - connect_arg = mock_connect.call_args[0][0] - assert connect_arg.user == "new_user" - assert connect_arg.password == "new_pass" - assert connect_arg.dsn == "localhost:1521/TEST" - - # Verify disconnect was NOT called (no other databases with connections) - mock_disconnect.assert_not_called() - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - async def test_databases_update_raises_404_when_not_found(self, mock_get_databases, make_database_auth): - """databases_update should raise 404 when database not found.""" - mock_get_databases.side_effect = ValueError("Database not found") - - payload = make_database_auth() - - with pytest.raises(HTTPException) as exc_info: - await databases.databases_update(name="NONEXISTENT", payload=payload) - - assert exc_info.value.status_code == 404 - mock_get_databases.assert_called_once_with(db_name="NONEXISTENT", validate=False) - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - @patch("server.api.v1.databases.utils_databases.connect") - async def test_databases_update_raises_400_on_value_error( - self, mock_connect, mock_get_databases, make_database, make_database_auth - ): - """databases_update should raise 400 on ValueError during connect.""" - existing_db = make_database(name="TEST_DB") - mock_get_databases.return_value = existing_db - mock_connect.side_effect = ValueError("Invalid parameters") - - payload = make_database_auth() - - with pytest.raises(HTTPException) as exc_info: - await databases.databases_update(name="TEST_DB", payload=payload) - - assert exc_info.value.status_code == 400 - - # Verify get_databases was called to retrieve the target database - mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=False) - - # Verify connect was called with the payload - mock_connect.assert_called_once() - connect_arg = mock_connect.call_args[0][0] - assert connect_arg.user == payload.user - assert connect_arg.dsn == payload.dsn - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - @patch("server.api.v1.databases.utils_databases.connect") - async def test_databases_update_raises_401_on_permission_error( - self, mock_connect, mock_get_databases, make_database, make_database_auth - ): - """databases_update should raise 401 on PermissionError during connect.""" - existing_db = make_database(name="TEST_DB") - mock_get_databases.return_value = existing_db - mock_connect.side_effect = PermissionError("Access denied") - - payload = make_database_auth() - - with pytest.raises(HTTPException) as exc_info: - await databases.databases_update(name="TEST_DB", payload=payload) - - assert exc_info.value.status_code == 401 - - # Verify 
get_databases was called to retrieve the target database - mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=False) - - # Verify connect was called with the payload - mock_connect.assert_called_once() - connect_arg = mock_connect.call_args[0][0] - assert connect_arg.user == payload.user - assert connect_arg.dsn == payload.dsn - - @pytest.mark.asyncio - @patch("server.api.v1.databases.utils_databases.get_databases") - @patch("server.api.v1.databases.utils_databases.connect") - @patch("server.api.v1.databases.utils_databases.disconnect") - async def test_databases_update_disconnects_other_databases( - self, mock_disconnect, mock_connect, mock_get_databases, make_database, make_database_auth - ): - """databases_update should disconnect OTHER database connections, not the newly connected one. - - When connecting to a database, the system enforces single-connection mode: - only one database can be connected at a time. This test verifies that when - updating/connecting to TEST_DB, any existing connections on OTHER databases - are properly disconnected using their own connection objects. - - Expected behavior: - 1. Connect to TEST_DB with new connection - 2. For each other database with an active connection, disconnect it - 3. The disconnect call should receive the OTHER database's connection - 4. The newly connected database's connection should remain intact - """ - # Setup: TEST_DB is the database being updated - target_db = make_database(name="TEST_DB", user="old_user") - - # Setup: OTHER_DB has an existing connection that should be disconnected - other_db = make_database(name="OTHER_DB") - other_db_existing_connection = MagicMock(name="other_db_connection") - other_db.set_connection(other_db_existing_connection) - other_db.connected = True - - # Setup: ANOTHER_DB has no connection (should not trigger disconnect) - another_db = make_database(name="ANOTHER_DB") - another_db.connected = False - - # Mock: First call returns target DB, second call returns all DBs for cleanup - mock_get_databases.side_effect = [target_db, [target_db, other_db, another_db]] - - # Mock: New connection for TEST_DB - new_connection = MagicMock(name="new_test_db_connection") - mock_connect.return_value = new_connection - - # Mock: disconnect returns None (connection closed) - mock_disconnect.return_value = None - - payload = make_database_auth(user="new_user", password="new_pass", dsn="localhost:1521/TEST") - - # Execute - result = await databases.databases_update(name="TEST_DB", payload=payload) - - # Verify: Target database is connected with new connection - assert result.connected is True - assert result.user == "new_user" - - # Verify: disconnect was called exactly once (only OTHER_DB had a connection) - mock_disconnect.assert_called_once() - - # CRITICAL ASSERTION: disconnect must be called with OTHER_DB's connection, - # not the new TEST_DB connection - actual_disconnect_arg = mock_disconnect.call_args[0][0] - assert actual_disconnect_arg is other_db_existing_connection, ( - f"Expected disconnect to be called with other_db's connection, " - f"but was called with: {actual_disconnect_arg}" - ) - assert actual_disconnect_arg is not new_connection, ( - "disconnect should NOT be called with the newly created connection" - ) - - # Verify: OTHER_DB is now disconnected - assert other_db.connected is False - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(databases, "auth") - - 
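# A hedged sketch of the single-connection rule that
# test_databases_update_disconnects_other_databases (above) verifies: after
# connecting the target database, every OTHER connected database is
# disconnected using its own connection object. Names are illustrative,
# inferred from the mocks, and not the actual endpoint code.
def enforce_single_connection(target, all_databases, connect, disconnect):
    target.set_connection(connect(target))  # new connection for the target DB
    target.connected = True
    for db in all_databases:
        if db.name != target.name and db.connected:
            disconnect(db.connection)  # close the OTHER database's own connection
            db.connected = False
    return target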
def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in databases.auth.routes] - - assert "" in routes - assert "/{name}" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(databases, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert databases.logger.name == "endpoints.v1.databases" diff --git a/test/unit/server/api/v1/test_v1_embed.py b/test/unit/server/api/v1/test_v1_embed.py deleted file mode 100644 index 17a442c9..00000000 --- a/test/unit/server/api/v1/test_v1_embed.py +++ /dev/null @@ -1,553 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/embed.py -Tests for document embedding and vector store endpoints. -""" -# pylint: disable=protected-access -# pylint: disable=redefined-outer-name -# Pytest fixtures use parameter injection where fixture names match parameters - -from io import BytesIO -from pathlib import Path -from test.unit.server.api.conftest import create_mock_aiohttp_session -from unittest.mock import patch, MagicMock, AsyncMock -import json - -import pytest -from fastapi import HTTPException, UploadFile -from pydantic import HttpUrl - -from common.schema import DatabaseVectorStorage, VectorStoreRefreshRequest -from server.api.v1 import embed -from server.api.utils.databases import DbException - - -@pytest.fixture -def split_embed_mocks(): - """Fixture providing bundled mocks for split_embed tests.""" - with patch("server.api.v1.embed.utils_oci.get") as mock_oci_get, \ - patch("server.api.v1.embed.utils_embed.get_temp_directory") as mock_get_temp, \ - patch("server.api.v1.embed.utils_embed.load_and_split_documents") as mock_load_split, \ - patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed, \ - patch("server.api.v1.embed.functions.get_vs_table") as mock_get_vs_table, \ - patch("server.api.v1.embed.utils_embed.populate_vs") as mock_populate, \ - patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db, \ - patch("shutil.rmtree") as mock_rmtree: - yield { - "oci_get": mock_oci_get, - "get_temp": mock_get_temp, - "load_split": mock_load_split, - "get_embed": mock_get_embed, - "get_vs_table": mock_get_vs_table, - "populate": mock_populate, - "get_db": mock_get_db, - "rmtree": mock_rmtree, - } - - -class TestExtractProviderErrorMessage: - """Tests for the _extract_provider_error_message helper function.""" - - def test_exception_with_message(self): - """Test extraction of exception with message""" - error = Exception("Something went wrong") - result = embed._extract_provider_error_message(error) - assert result == "Something went wrong" - - def test_exception_without_message(self): - """Test extraction of exception without message""" - error = ValueError() - result = embed._extract_provider_error_message(error) - assert result == "Error: ValueError" - - def test_openai_quota_exceeded(self): - """Test extraction of OpenAI quota exceeded error message""" - error_msg = ( - "Error code: 429 - {'error': {'message': 'You exceeded your current quota, " - "please check your plan and billing details.', 'type': 'insufficient_quota'}}" - ) - error = Exception(error_msg) - result = embed._extract_provider_error_message(error) - assert result == error_msg - - 
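# A minimal sketch of the helper these tests exercise, inferred only from the
# assertions above (the exception's message is passed through unchanged, with
# an "Error: <ClassName>" fallback when the exception carries no text). The
# actual implementation in server/api/v1/embed.py may differ.
def _extract_provider_error_message_sketch(error: Exception) -> str:
    message = str(error)
    return message if message else f"Error: {type(error).__name__}"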
def test_openai_rate_limit(self): - """Test extraction of OpenAI rate limit error message""" - error_msg = "Rate limit exceeded. Please try again later." - error = Exception(error_msg) - result = embed._extract_provider_error_message(error) - assert result == error_msg - - def test_complex_error_message(self): - """Test extraction of complex multi-line error message""" - error_msg = "Connection failed\nTimeout: 30s\nHost: api.example.com" - error = Exception(error_msg) - result = embed._extract_provider_error_message(error) - assert result == error_msg - - @pytest.mark.parametrize( - "error_message", - [ - "OpenAI API key is invalid", - "Cohere API error occurred", - "OCI service error", - "Database connection failed", - "Rate limit exceeded for model xyz", - ], - ) - def test_various_error_messages(self, error_message): - """Test that various error messages are passed through correctly""" - error = Exception(error_message) - result = embed._extract_provider_error_message(error) - assert result == error_message - - -class TestEmbedDropVs: - """Tests for the embed_drop_vs endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_databases.connect") - @patch("server.api.v1.embed.utils_databases.drop_vs") - async def test_embed_drop_vs_success(self, mock_drop, mock_connect, mock_get_db, make_database): - """embed_drop_vs should drop vector store and return success.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_connect.return_value = MagicMock() - mock_drop.return_value = None - - result = await embed.embed_drop_vs(vs="VS_TEST", client="test_client") - - assert result.status_code == 200 - mock_drop.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_databases.connect") - @patch("server.api.v1.embed.utils_databases.drop_vs") - async def test_embed_drop_vs_raises_400_on_db_exception(self, mock_drop, mock_connect, mock_get_db, make_database): - """embed_drop_vs should raise 400 on DbException.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_connect.return_value = MagicMock() - mock_drop.side_effect = DbException(status_code=400, detail="Table not found") - - with pytest.raises(HTTPException) as exc_info: - await embed.embed_drop_vs(vs="VS_NONEXISTENT", client="test_client") - - assert exc_info.value.status_code == 400 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_databases.connect") - @patch("server.api.v1.embed.utils_databases.drop_vs") - async def test_embed_drop_vs_response_contains_vs_name(self, mock_drop, mock_connect, mock_get_db, make_database): - """embed_drop_vs response should contain vector store name.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_connect.return_value = MagicMock() - mock_drop.return_value = None - - result = await embed.embed_drop_vs(vs="VS_MY_STORE", client="test_client") - - body = json.loads(result.body) - assert "VS_MY_STORE" in body["message"] - - -class TestEmbedGetFiles: - """Tests for the embed_get_files endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_files") - async def test_embed_get_files_success(self, mock_get_files, mock_get_db, make_database): - """embed_get_files should return file list.""" - 
mock_db = make_database() - mock_get_db.return_value = mock_db - mock_get_files.return_value = [ - {"filename": "file1.pdf", "chunks": 10}, - {"filename": "file2.txt", "chunks": 5}, - ] - - result = await embed.embed_get_files(vs="VS_TEST", client="test_client") - - assert result.status_code == 200 - mock_get_files.assert_called_once_with(mock_db, "VS_TEST") - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_files") - async def test_embed_get_files_raises_400_on_exception(self, mock_get_files, mock_get_db, make_database): - """embed_get_files should raise 400 on exception.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_get_files.side_effect = Exception("Query failed") - - with pytest.raises(HTTPException) as exc_info: - await embed.embed_get_files(vs="VS_TEST", client="test_client") - - assert exc_info.value.status_code == 400 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_files") - async def test_embed_get_files_empty_list(self, mock_get_files, mock_get_db, make_database): - """embed_get_files should return empty list for empty vector store.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_get_files.return_value = [] - - result = await embed.embed_get_files(vs="VS_EMPTY", client="test_client") - - assert result.status_code == 200 - body = json.loads(result.body) - assert body == [] - - -class TestCommentVs: - """Tests for the comment_vs endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.update_vs_comment") - async def test_comment_vs_success(self, mock_update_comment, mock_get_db, make_database, make_vector_store): - """comment_vs should update vector store comment and return success.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_update_comment.return_value = None - - request = make_vector_store(vector_store="VS_TEST") - - result = await embed.comment_vs(request=request, client="test_client") - - assert result.status_code == 200 - body = json.loads(result.body) - assert "comment updated" in body["message"] - mock_update_comment.assert_called_once_with(vector_store=request, db_details=mock_db) - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.update_vs_comment") - async def test_comment_vs_calls_get_client_database( - self, mock_update_comment, mock_get_db, make_database, make_vector_store - ): - """comment_vs should call get_client_database with correct client.""" - mock_db = make_database() - mock_get_db.return_value = mock_db - mock_update_comment.return_value = None - - request = make_vector_store() - - await embed.comment_vs(request=request, client="my_client") - - mock_get_db.assert_called_once_with("my_client") - - -class TestStoreSqlFile: - """Tests for the store_sql_file endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - @patch("server.api.v1.embed.functions.run_sql_query") - async def test_store_sql_file_success(self, mock_run_sql, mock_get_temp, tmp_path): - """store_sql_file should execute SQL and return file path.""" - mock_get_temp.return_value = tmp_path - mock_run_sql.return_value = "result.csv" - - result = await embed.store_sql_file(request=["conn_str", 
"SELECT * FROM table"], client="test_client") - - assert result.status_code == 200 - body = json.loads(result.body) - assert "result.csv" in body - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - @patch("server.api.v1.embed.functions.run_sql_query") - async def test_store_sql_file_calls_run_sql_query(self, mock_run_sql, mock_get_temp, tmp_path): - """store_sql_file should call run_sql_query with correct params.""" - mock_get_temp.return_value = tmp_path - mock_run_sql.return_value = "output.csv" - - await embed.store_sql_file(request=["db_conn", "SELECT 1"], client="test_client") - - mock_run_sql.assert_called_once_with(db_conn="db_conn", query="SELECT 1", base_path=tmp_path) - - -class TestStoreWebFile: - """Tests for the store_web_file endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - @patch("server.api.v1.embed.web_parse.fetch_and_extract_sections") - @patch("server.api.v1.embed.web_parse.slugify") - @patch("aiohttp.ClientSession") - async def test_store_web_file_html_success( - self, mock_session_class, mock_slugify, mock_fetch_sections, mock_get_temp, tmp_path - ): - """store_web_file should fetch HTML and extract sections.""" - mock_get_temp.return_value = tmp_path - mock_slugify.return_value = "test-page" - mock_fetch_sections.return_value = [{"title": "Section 1", "content": "Content 1"}] - - mock_response = AsyncMock() - mock_response.headers = {"Content-Type": "text/html"} - mock_response.read = AsyncMock(return_value=b"") - create_mock_aiohttp_session(mock_session_class, mock_response) - - result = await embed.store_web_file(request=[HttpUrl("https://example.com/page")], client="test_client") - - assert result.status_code == 200 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - @patch("aiohttp.ClientSession") - async def test_store_web_file_pdf_success(self, mock_session_class, mock_get_temp, tmp_path): - """store_web_file should download PDF files.""" - mock_get_temp.return_value = tmp_path - - mock_response = AsyncMock() - mock_response.headers = {"Content-Type": "application/pdf"} - mock_response.read = AsyncMock(return_value=b"%PDF-1.4") - create_mock_aiohttp_session(mock_session_class, mock_response) - - result = await embed.store_web_file(request=[HttpUrl("https://example.com/doc.pdf")], client="test_client") - - assert result.status_code == 200 - - -class TestStoreLocalFile: - """Tests for the store_local_file endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_store_local_file_success(self, mock_get_temp, tmp_path): - """store_local_file should save uploaded files.""" - mock_get_temp.return_value = tmp_path - - mock_file = UploadFile(file=BytesIO(b"Test content"), filename="test.txt") - - result = await embed.store_local_file(files=[mock_file], client="test_client") - - assert result.status_code == 200 - body = json.loads(result.body) - assert "test.txt" in body - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_store_local_file_creates_metadata(self, mock_get_temp, tmp_path): - """store_local_file should create metadata file.""" - mock_get_temp.return_value = tmp_path - - mock_file = UploadFile(file=BytesIO(b"Test content"), filename="test.txt") - - await embed.store_local_file(files=[mock_file], client="test_client") - - metadata_file = tmp_path / ".file_metadata.json" - assert metadata_file.exists() - - 
@pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_store_local_file_multiple_files(self, mock_get_temp, tmp_path): - """store_local_file should handle multiple files.""" - mock_get_temp.return_value = tmp_path - - files = [ - UploadFile(file=BytesIO(b"Content 1"), filename="file1.txt"), - UploadFile(file=BytesIO(b"Content 2"), filename="file2.txt"), - ] - - result = await embed.store_local_file(files=files, client="test_client") - - body = json.loads(result.body) - assert len(body) == 2 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_store_local_file_metadata_excludes_metadata_file(self, mock_get_temp, tmp_path): - """store_local_file should not include metadata file in response.""" - mock_get_temp.return_value = tmp_path - - mock_file = UploadFile(file=BytesIO(b"Content"), filename="test.txt") - - result = await embed.store_local_file(files=[mock_file], client="test_client") - - body = json.loads(result.body) - assert ".file_metadata.json" not in body - - -class TestSplitEmbed: - """Tests for the split_embed endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_split_embed_raises_404_when_no_files(self, mock_get_temp, mock_oci_get, tmp_path, make_oci_config): - """split_embed should raise 404 when no files found.""" - mock_oci_get.return_value = make_oci_config() - mock_get_temp.return_value = tmp_path # Empty directory - - request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) - - with pytest.raises(HTTPException) as exc_info: - await embed.split_embed(request=request, rate_limit=0, client="test_client") - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_embed.get_temp_directory") - async def test_split_embed_raises_404_when_folder_not_found(self, mock_get_temp, mock_oci_get, make_oci_config): - """split_embed should raise 404 when folder not found.""" - mock_oci_get.return_value = make_oci_config() - mock_get_temp.return_value = Path("/nonexistent/path") - - request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) - - with pytest.raises(HTTPException) as exc_info: - await embed.split_embed(request=request, rate_limit=0, client="test_client") - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - async def test_split_embed_success( - self, split_embed_mocks, tmp_path, make_oci_config, make_database - ): - """split_embed should process files and populate vector store.""" - mocks = split_embed_mocks - mocks["oci_get"].return_value = make_oci_config() - mocks["get_temp"].return_value = tmp_path - mocks["load_split"].return_value = (["doc1", "doc2"], None) - mocks["get_embed"].return_value = MagicMock() - mocks["get_vs_table"].return_value = ("VS_TEST", "test_alias") - mocks["populate"].return_value = None - mocks["get_db"].return_value = make_database() - - # Create a test file - (tmp_path / "test.txt").write_text("Test content") - - request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) - - result = await embed.split_embed(request=request, rate_limit=0, client="test_client") - - assert result.status_code == 200 - mocks["populate"].assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - 
@patch("server.api.v1.embed.utils_embed.get_temp_directory") - @patch("server.api.v1.embed.utils_embed.load_and_split_documents") - @patch("shutil.rmtree") - async def test_split_embed_raises_500_on_value_error( - self, _mock_rmtree, mock_load_split, mock_get_temp, mock_oci_get, tmp_path, make_oci_config - ): - """split_embed should raise 500 on ValueError during processing.""" - mock_oci_get.return_value = make_oci_config() - mock_get_temp.return_value = tmp_path - mock_load_split.side_effect = ValueError("Invalid document format") - - # Create a test file - (tmp_path / "test.txt").write_text("Test content") - - request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) - - with pytest.raises(HTTPException) as exc_info: - await embed.split_embed(request=request, rate_limit=0, client="test_client") - - assert exc_info.value.status_code == 500 - - -class TestRefreshVectorStore: - """Tests for the refresh_vector_store endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") - @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") - async def test_refresh_vector_store_no_files( - self, - mock_get_objects, - mock_get_vs, - mock_get_db, - mock_oci_get, - make_oci_config, - make_database, - make_vector_store, - ): - """refresh_vector_store should return success when no files.""" - mock_oci_get.return_value = make_oci_config() - mock_get_db.return_value = make_database() - mock_get_vs.return_value = make_vector_store() - mock_get_objects.return_value = [] - - request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") - - result = await embed.refresh_vector_store(request=request, client="test_client") - - assert result.status_code == 200 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - async def test_refresh_vector_store_raises_400_on_value_error(self, mock_oci_get): - """refresh_vector_store should raise 400 on ValueError.""" - mock_oci_get.side_effect = ValueError("Invalid config") - - request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") - - with pytest.raises(HTTPException) as exc_info: - await embed.refresh_vector_store(request=request, client="test_client") - - assert exc_info.value.status_code == 400 - - @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") - async def test_refresh_vector_store_raises_500_on_db_exception( - self, mock_get_vs, mock_get_db, mock_oci_get, make_oci_config, make_database - ): - """refresh_vector_store should raise 500 on DbException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_db.return_value = make_database() - mock_get_vs.side_effect = DbException(status_code=500, detail="Database error") - - request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") - - with pytest.raises(HTTPException) as exc_info: - await embed.refresh_vector_store(request=request, client="test_client") - - assert exc_info.value.status_code == 500 - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(embed, "auth") - - def 
test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in embed.auth.routes] - - assert "/{vs}" in routes - assert "/{vs}/files" in routes - assert "/comment" in routes - assert "/sql/store" in routes - assert "/web/store" in routes - assert "/local/store" in routes - assert "/" in routes - assert "/refresh" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(embed, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert embed.logger.name == "api.v1.embed" diff --git a/test/unit/server/api/v1/test_v1_mcp.py b/test/unit/server/api/v1/test_v1_mcp.py deleted file mode 100644 index dc10c82b..00000000 --- a/test/unit/server/api/v1/test_v1_mcp.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/mcp.py -Tests for MCP (Model Context Protocol) endpoints. -""" - -# pylint: disable=too-few-public-methods - -from unittest.mock import patch, MagicMock, AsyncMock -import pytest - -from server.api.v1 import mcp - - -class TestGetMcp: - """Tests for the get_mcp dependency function.""" - - def test_get_mcp_returns_fastmcp_app(self): - """get_mcp should return the FastMCP app from request state.""" - mock_request = MagicMock() - mock_fastmcp = MagicMock() - mock_request.app.state.fastmcp_app = mock_fastmcp - - result = mcp.get_mcp(mock_request) - - assert result == mock_fastmcp - - -class TestGetClient: - """Tests for the get_client endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.utils_mcp.get_client") - async def test_get_client_returns_config(self, mock_get_client): - """get_client should return MCP client configuration.""" - expected_config = { - "mcpServers": { - "optimizer": { - "type": "streamableHttp", - "transport": "streamable_http", - "url": "http://127.0.0.1:8000/mcp/", - "headers": {"Authorization": "Bearer test-key"}, - } - } - } - mock_get_client.return_value = expected_config - - result = await mcp.get_client(server="http://127.0.0.1", port=8000) - - assert result == expected_config - mock_get_client.assert_called_once_with("http://127.0.0.1", 8000) - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.utils_mcp.get_client") - async def test_get_client_with_default_params(self, mock_get_client): - """get_client should use default parameters.""" - mock_get_client.return_value = {} - - await mcp.get_client() - - mock_get_client.assert_called_once_with(None, None) - - -class TestGetTools: - """Tests for the get_tools endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.Client") - async def test_get_tools_returns_tool_list(self, mock_client_class, mock_fastmcp): - """get_tools should return list of MCP tools.""" - mock_tool1 = MagicMock() - mock_tool1.model_dump.return_value = {"name": "optimizer_tool1"} - mock_tool2 = MagicMock() - mock_tool2.model_dump.return_value = {"name": "optimizer_tool2"} - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_tools = AsyncMock(return_value=[mock_tool1, mock_tool2]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - result = await mcp.get_tools(mcp_engine=mock_fastmcp) - - assert 
len(result) == 2 - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.Client") - async def test_get_tools_returns_empty_list(self, mock_client_class, mock_fastmcp): - """get_tools should return empty list when no tools.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_tools = AsyncMock(return_value=[]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - result = await mcp.get_tools(mcp_engine=mock_fastmcp) - - assert result == [] - - -class TestMcpListResources: - """Tests for the mcp_list_resources endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.Client") - async def test_mcp_list_resources_returns_resource_list(self, mock_client_class, mock_fastmcp): - """mcp_list_resources should return list of resources.""" - mock_resource = MagicMock() - mock_resource.model_dump.return_value = {"name": "test_resource", "uri": "resource://test"} - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_resources = AsyncMock(return_value=[mock_resource]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - result = await mcp.mcp_list_resources(mcp_engine=mock_fastmcp) - - assert len(result) == 1 - mock_client.close.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.mcp.Client") - async def test_mcp_list_resources_returns_empty_list(self, mock_client_class, mock_fastmcp): - """mcp_list_resources should return empty list when no resources.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_resources = AsyncMock(return_value=[]) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - result = await mcp.mcp_list_resources(mcp_engine=mock_fastmcp) - - assert result == [] - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(mcp, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in mcp.auth.routes] - - assert "/client" in routes - assert "/tools" in routes - assert "/resources" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(mcp, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert mcp.logger.name == "api.v1.mcp" diff --git a/test/unit/server/api/v1/test_v1_mcp_prompts.py b/test/unit/server/api/v1/test_v1_mcp_prompts.py deleted file mode 100644 index 46c518c8..00000000 --- a/test/unit/server/api/v1/test_v1_mcp_prompts.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/mcp_prompts.py -Tests for MCP prompt management endpoints. 
-""" - -from unittest.mock import patch, MagicMock, AsyncMock -import pytest -from fastapi import HTTPException - -from server.api.v1 import mcp_prompts - - -class TestMcpListPrompts: - """Tests for the mcp_list_prompts endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") - async def test_mcp_list_prompts_metadata_only(self, mock_list_prompts, mock_fastmcp): - """mcp_list_prompts should return metadata only when full=False.""" - mock_prompt = MagicMock() - mock_prompt.name = "optimizer_test-prompt" - mock_prompt.model_dump.return_value = {"name": "optimizer_test-prompt", "description": "Test"} - mock_list_prompts.return_value = [mock_prompt] - - result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) - - assert len(result) == 1 - assert result[0]["name"] == "optimizer_test-prompt" - mock_list_prompts.assert_called_once_with(mock_fastmcp) - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.utils_settings.get_mcp_prompts_with_overrides") - async def test_mcp_list_prompts_full(self, mock_get_prompts, mock_fastmcp, make_mcp_prompt): - """mcp_list_prompts should return full prompts with text when full=True.""" - mock_prompt = make_mcp_prompt(name="optimizer_test-prompt") - mock_get_prompts.return_value = [mock_prompt] - - result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=True) - - assert len(result) == 1 - assert "text" in result[0] - mock_get_prompts.assert_called_once_with(mock_fastmcp) - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") - async def test_mcp_list_prompts_filters_non_optimizer_prompts(self, mock_list_prompts, mock_fastmcp): - """mcp_list_prompts should filter out non-optimizer prompts.""" - optimizer_prompt = MagicMock() - optimizer_prompt.name = "optimizer_test-prompt" - optimizer_prompt.model_dump.return_value = {"name": "optimizer_test-prompt"} - - other_prompt = MagicMock() - other_prompt.name = "other-prompt" - other_prompt.model_dump.return_value = {"name": "other-prompt"} - - mock_list_prompts.return_value = [optimizer_prompt, other_prompt] - - result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) - - assert len(result) == 1 - assert result[0]["name"] == "optimizer_test-prompt" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.utils_mcp.list_prompts") - async def test_mcp_list_prompts_empty_list(self, mock_list_prompts, mock_fastmcp): - """mcp_list_prompts should return empty list when no prompts.""" - mock_list_prompts.return_value = [] - - result = await mcp_prompts.mcp_list_prompts(mcp_engine=mock_fastmcp, full=False) - - assert result == [] - - -class TestMcpGetPrompt: - """Tests for the mcp_get_prompt endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - async def test_mcp_get_prompt_success(self, mock_client_class, mock_fastmcp): - """mcp_get_prompt should return prompt content.""" - mock_prompt_result = MagicMock() - mock_prompt_result.messages = [{"role": "user", "content": "Test content"}] - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.get_prompt = AsyncMock(return_value=mock_prompt_result) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - result = await mcp_prompts.mcp_get_prompt(name="optimizer_test-prompt", mcp_engine=mock_fastmcp) - - assert result == mock_prompt_result - 
mock_client.get_prompt.assert_called_once_with(name="optimizer_test-prompt") - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - async def test_mcp_get_prompt_closes_client(self, mock_client_class, mock_fastmcp): - """mcp_get_prompt should close client after use.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.get_prompt = AsyncMock(return_value=MagicMock()) - mock_client.close = AsyncMock() - mock_client_class.return_value = mock_client - - await mcp_prompts.mcp_get_prompt(name="test-prompt", mcp_engine=mock_fastmcp) - - mock_client.close.assert_called_once() - - -class TestMcpUpdatePrompt: - """Tests for the mcp_update_prompt endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - @patch("server.api.v1.mcp_prompts.cache") - async def test_mcp_update_prompt_success(self, mock_cache, mock_client_class, mock_fastmcp): - """mcp_update_prompt should update prompt and return success.""" - mock_prompt = MagicMock() - mock_prompt.name = "optimizer_test-prompt" - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) - mock_client_class.return_value = mock_client - - payload = {"instructions": "You are a helpful assistant."} - - result = await mcp_prompts.mcp_update_prompt( - name="optimizer_test-prompt", payload=payload, mcp_engine=mock_fastmcp - ) - - assert result["name"] == "optimizer_test-prompt" - assert "updated successfully" in result["message"] - mock_cache.set_override.assert_called_once_with("optimizer_test-prompt", "You are a helpful assistant.") - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - async def test_mcp_update_prompt_missing_instructions(self, _mock_client_class, mock_fastmcp): - """mcp_update_prompt should raise 400 when instructions missing.""" - payload = {"other_field": "value"} - - with pytest.raises(HTTPException) as exc_info: - await mcp_prompts.mcp_update_prompt(name="test-prompt", payload=payload, mcp_engine=mock_fastmcp) - - assert exc_info.value.status_code == 400 - assert "instructions" in str(exc_info.value.detail) - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - async def test_mcp_update_prompt_not_found(self, mock_client_class, mock_fastmcp): - """mcp_update_prompt should raise 404 when prompt not found.""" - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client.list_prompts = AsyncMock(return_value=[]) - mock_client_class.return_value = mock_client - - payload = {"instructions": "New instructions"} - - with pytest.raises(HTTPException) as exc_info: - await mcp_prompts.mcp_update_prompt(name="nonexistent-prompt", payload=payload, mcp_engine=mock_fastmcp) - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - @patch("server.api.v1.mcp_prompts.Client") - @patch("server.api.v1.mcp_prompts.cache") - async def test_mcp_update_prompt_handles_exception(self, mock_cache, mock_client_class, mock_fastmcp): - """mcp_update_prompt should raise 500 on unexpected exception.""" - mock_prompt = MagicMock() - mock_prompt.name = "optimizer_test-prompt" - - mock_client = MagicMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) 
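-        # __aenter__/__aexit__ are stubbed so "async with Client(...)" yields this mock;
-        # list_prompts is an AsyncMock because the endpoint awaits its result to look up the prompt.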
- mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) - mock_client_class.return_value = mock_client - - mock_cache.set_override.side_effect = RuntimeError("Cache error") - - payload = {"instructions": "New instructions"} - - with pytest.raises(HTTPException) as exc_info: - await mcp_prompts.mcp_update_prompt(name="optimizer_test-prompt", payload=payload, mcp_engine=mock_fastmcp) - - assert exc_info.value.status_code == 500 - - @pytest.mark.asyncio - async def test_mcp_update_prompt_none_instructions(self, mock_fastmcp): - """mcp_update_prompt should raise 400 when instructions is None.""" - payload = {"instructions": None} - - with pytest.raises(HTTPException) as exc_info: - await mcp_prompts.mcp_update_prompt(name="test-prompt", payload=payload, mcp_engine=mock_fastmcp) - - assert exc_info.value.status_code == 400 - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(mcp_prompts, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in mcp_prompts.auth.routes] - - assert "/prompts" in routes - assert "/prompts/{name}" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(mcp_prompts, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert mcp_prompts.logger.name == "api.v1.mcp_prompts" diff --git a/test/unit/server/api/v1/test_v1_models.py b/test/unit/server/api/v1/test_v1_models.py deleted file mode 100644 index 6a4f721e..00000000 --- a/test/unit/server/api/v1/test_v1_models.py +++ /dev/null @@ -1,254 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/models.py -Tests for model configuration endpoints. 
-""" - -import json -from unittest.mock import patch - -import pytest -from fastapi import HTTPException - -from server.api.v1 import models -from server.api.utils import models as utils_models - - -class TestModelsList: - """Tests for the models_list endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_list_returns_all_models(self, mock_get, make_model): - """models_list should return all configured models.""" - model_list = [ - make_model(model_id="gpt-4", provider="openai"), - make_model(model_id="claude-3", provider="anthropic"), - ] - mock_get.return_value = model_list - - result = await models.models_list() - - assert result == model_list - mock_get.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_list_with_type_filter(self, mock_get): - """models_list should filter by model type when provided.""" - mock_get.return_value = [] - - await models.models_list(model_type="ll") - - mock_get.assert_called_once() - # Verify the model_type was passed (FastAPI Query wraps the value) - call_kwargs = mock_get.call_args.kwargs - assert call_kwargs.get("model_type") == "ll" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_list_with_include_disabled(self, mock_get): - """models_list should include disabled models when requested.""" - mock_get.return_value = [] - - await models.models_list(include_disabled=True) - - mock_get.assert_called_once() - # Verify the include_disabled was passed - call_kwargs = mock_get.call_args.kwargs - assert call_kwargs.get("include_disabled") is True - - -class TestModelsSupported: - """Tests for the models_supported endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get_supported") - async def test_models_supported_returns_supported_list(self, mock_get_supported): - """models_supported should return list of supported models.""" - supported_models = [ - {"provider": "openai", "models": ["gpt-4", "gpt-4o"]}, - ] - mock_get_supported.return_value = supported_models - - result = await models.models_supported(model_provider="openai") - - assert result == supported_models - mock_get_supported.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get_supported") - async def test_models_supported_filters_by_type(self, mock_get_supported): - """models_supported should filter by model type when provided.""" - mock_get_supported.return_value = [] - - await models.models_supported(model_provider="openai", model_type="ll") - - mock_get_supported.assert_called_once() - call_kwargs = mock_get_supported.call_args.kwargs - assert call_kwargs.get("model_provider") == "openai" - assert call_kwargs.get("model_type") == "ll" - - -class TestModelsGet: - """Tests for the models_get endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_get_returns_single_model(self, mock_get, make_model): - """models_get should return a single model by ID.""" - model = make_model(model_id="gpt-4", provider="openai") - mock_get.return_value = (model,) # Returns a tuple that unpacks - - result = await models.models_get(model_provider="openai", model_id="gpt-4") - - assert result == model - mock_get.assert_called_once_with(model_provider="openai", model_id="gpt-4") - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_get_raises_404_when_not_found(self, mock_get): - 
"""models_get should raise 404 when model not found.""" - mock_get.side_effect = utils_models.UnknownModelError("Model not found") - - with pytest.raises(HTTPException) as exc_info: - await models.models_get(model_provider="openai", model_id="nonexistent") - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.get") - async def test_models_get_raises_404_on_multiple_results(self, mock_get, make_model): - """models_get should raise 404 when multiple models match.""" - # Returning a tuple with more than 1 element causes ValueError on unpack - mock_get.return_value = (make_model(), make_model()) - - with pytest.raises(HTTPException) as exc_info: - await models.models_get(model_provider="openai", model_id="gpt-4") - - assert exc_info.value.status_code == 404 - - -class TestModelsUpdate: - """Tests for the models_update endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.update") - async def test_models_update_returns_updated_model(self, mock_update, make_model): - """models_update should return the updated model.""" - updated_model = make_model(model_id="gpt-4", provider="openai", enabled=False) - mock_update.return_value = updated_model - - payload = make_model(model_id="gpt-4", provider="openai") - result = await models.models_update(payload=payload) - - assert result == updated_model - mock_update.assert_called_once_with(payload=payload) - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.update") - async def test_models_update_raises_404_when_not_found(self, mock_update, make_model): - """models_update should raise 404 when model not found.""" - mock_update.side_effect = utils_models.UnknownModelError("Model not found") - - payload = make_model(model_id="nonexistent", provider="openai") - - with pytest.raises(HTTPException) as exc_info: - await models.models_update(payload=payload) - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.update") - async def test_models_update_raises_422_on_unreachable_url(self, mock_update, make_model): - """models_update should raise 422 when API URL is unreachable.""" - mock_update.side_effect = utils_models.URLUnreachableError("URL unreachable") - - payload = make_model(model_id="gpt-4", provider="openai") - - with pytest.raises(HTTPException) as exc_info: - await models.models_update(payload=payload) - - assert exc_info.value.status_code == 422 - - -class TestModelsCreate: - """Tests for the models_create endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.create") - async def test_models_create_returns_new_model(self, mock_create, make_model): - """models_create should return newly created model.""" - new_model = make_model(model_id="new-model", provider="openai") - mock_create.return_value = new_model - - result = await models.models_create(payload=make_model(model_id="new-model", provider="openai")) - - assert result == new_model - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.create") - async def test_models_create_raises_409_on_duplicate(self, mock_create, make_model): - """models_create should raise 409 when model already exists.""" - mock_create.side_effect = utils_models.ExistsModelError("Model already exists") - - with pytest.raises(HTTPException) as exc_info: - await models.models_create(payload=make_model()) - - assert exc_info.value.status_code == 409 - - -class TestModelsDelete: - """Tests for the models_delete endpoint.""" - - 
@pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.delete") - async def test_models_delete_returns_200_on_success(self, mock_delete): - """models_delete should return 200 status on success.""" - mock_delete.return_value = None - - result = await models.models_delete(model_provider="openai", model_id="gpt-4") - - assert result.status_code == 200 - mock_delete.assert_called_once_with(model_provider="openai", model_id="gpt-4") - - @pytest.mark.asyncio - @patch("server.api.v1.models.utils_models.delete") - async def test_models_delete_response_contains_message(self, mock_delete): - """models_delete should return message with model name.""" - mock_delete.return_value = None - - result = await models.models_delete(model_provider="openai", model_id="gpt-4") - - body = json.loads(result.body) - assert "openai/gpt-4" in body["message"] - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(models, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in models.auth.routes] - - assert "" in routes - assert "/supported" in routes - assert "/{model_provider}/{model_id:path}" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(models, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert models.logger.name == "endpoints.v1.models" diff --git a/test/unit/server/api/v1/test_v1_oci.py b/test/unit/server/api/v1/test_v1_oci.py deleted file mode 100644 index 4402e96c..00000000 --- a/test/unit/server/api/v1/test_v1_oci.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/oci.py -Tests for OCI configuration and resource endpoints. 
-""" - -# pylint: disable=too-few-public-methods - -from unittest.mock import patch -import pytest -from fastapi import HTTPException - -from server.api.v1 import oci -from server.api.utils.oci import OciException - - -class TestOciList: - """Tests for the oci_list endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.utils_oci.get") - async def test_oci_list_returns_all_configs(self, mock_get, make_oci_config): - """oci_list should return all OCI configurations.""" - configs = [make_oci_config(auth_profile="DEFAULT"), make_oci_config(auth_profile="PROD")] - mock_get.return_value = configs - - result = await oci.oci_list() - - assert result == configs - mock_get.assert_called_once_with() - - @pytest.mark.asyncio - @patch("server.api.v1.oci.utils_oci.get") - async def test_oci_list_raises_404_on_value_error(self, mock_get): - """oci_list should raise 404 when ValueError occurs.""" - mock_get.side_effect = ValueError("No configs found") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list() - - assert exc_info.value.status_code == 404 - assert "OCI:" in str(exc_info.value.detail) - - -class TestOciGet: - """Tests for the oci_get endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.utils_oci.get") - async def test_oci_get_returns_single_config(self, mock_get, make_oci_config): - """oci_get should return a single OCI config by profile.""" - config = make_oci_config(auth_profile="DEFAULT") - mock_get.return_value = config - - result = await oci.oci_get(auth_profile="DEFAULT") - - assert result == config - mock_get.assert_called_once_with(auth_profile="DEFAULT") - - @pytest.mark.asyncio - @patch("server.api.v1.oci.utils_oci.get") - async def test_oci_get_raises_404_when_not_found(self, mock_get): - """oci_get should raise 404 when profile not found.""" - mock_get.side_effect = ValueError("Profile not found") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_get(auth_profile="NONEXISTENT") - - assert exc_info.value.status_code == 404 - - -class TestOciListRegions: - """Tests for the oci_list_regions endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_regions") - async def test_oci_list_regions_success(self, mock_get_regions, mock_oci_get, make_oci_config): - """oci_list_regions should return list of regions.""" - config = make_oci_config() - mock_oci_get.return_value = config - mock_get_regions.return_value = ["us-ashburn-1", "us-phoenix-1"] - - result = await oci.oci_list_regions(auth_profile="DEFAULT") - - assert result == ["us-ashburn-1", "us-phoenix-1"] - mock_get_regions.assert_called_once_with(config) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_regions") - async def test_oci_list_regions_raises_on_oci_exception(self, mock_get_regions, mock_oci_get, make_oci_config): - """oci_list_regions should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_regions.side_effect = OciException(status_code=401, detail="Unauthorized") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list_regions(auth_profile="DEFAULT") - - assert exc_info.value.status_code == 401 - - -class TestOciListGenai: - """Tests for the oci_list_genai endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_genai_models") - async def test_oci_list_genai_success(self, mock_get_genai, mock_oci_get, make_oci_config): - 
"""oci_list_genai should return list of GenAI models.""" - config = make_oci_config() - mock_oci_get.return_value = config - mock_get_genai.return_value = [{"name": "cohere.command"}, {"name": "meta.llama"}] - - result = await oci.oci_list_genai(auth_profile="DEFAULT") - - assert len(result) == 2 - mock_get_genai.assert_called_once_with(config, regional=False) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_genai_models") - async def test_oci_list_genai_raises_on_oci_exception(self, mock_get_genai, mock_oci_get, make_oci_config): - """oci_list_genai should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_genai.side_effect = OciException(status_code=403, detail="Forbidden") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list_genai(auth_profile="DEFAULT") - - assert exc_info.value.status_code == 403 - - -class TestOciListCompartments: - """Tests for the oci_list_compartments endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_compartments") - async def test_oci_list_compartments_success(self, mock_get_compartments, mock_oci_get, make_oci_config): - """oci_list_compartments should return compartment hierarchy.""" - config = make_oci_config() - mock_oci_get.return_value = config - compartments = {"root": {"name": "root", "children": []}} - mock_get_compartments.return_value = compartments - - result = await oci.oci_list_compartments(auth_profile="DEFAULT") - - assert result == compartments - mock_get_compartments.assert_called_once_with(config) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_compartments") - async def test_oci_list_compartments_raises_on_oci_exception( - self, mock_get_compartments, mock_oci_get, make_oci_config - ): - """oci_list_compartments should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_compartments.side_effect = OciException(status_code=500, detail="Internal error") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list_compartments(auth_profile="DEFAULT") - - assert exc_info.value.status_code == 500 - - -class TestOciListBuckets: - """Tests for the oci_list_buckets endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_buckets") - async def test_oci_list_buckets_success(self, mock_get_buckets, mock_oci_get, make_oci_config): - """oci_list_buckets should return list of buckets.""" - config = make_oci_config() - mock_oci_get.return_value = config - mock_get_buckets.return_value = ["bucket1", "bucket2"] - compartment_ocid = "ocid1.compartment.oc1..aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - - result = await oci.oci_list_buckets(auth_profile="DEFAULT", compartment_ocid=compartment_ocid) - - assert result == ["bucket1", "bucket2"] - mock_get_buckets.assert_called_once_with(compartment_ocid, config) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_buckets") - async def test_oci_list_buckets_raises_on_oci_exception(self, mock_get_buckets, mock_oci_get, make_oci_config): - """oci_list_buckets should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_buckets.side_effect = OciException(status_code=404, detail="Bucket not found") - compartment_ocid = 
"ocid1.compartment.oc1..aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list_buckets(auth_profile="DEFAULT", compartment_ocid=compartment_ocid) - - assert exc_info.value.status_code == 404 - - -class TestOciListBucketObjects: - """Tests for the oci_list_bucket_objects endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_bucket_objects") - async def test_oci_list_bucket_objects_success(self, mock_get_objects, mock_oci_get, make_oci_config): - """oci_list_bucket_objects should return list of objects.""" - config = make_oci_config() - mock_oci_get.return_value = config - mock_get_objects.return_value = ["file1.pdf", "file2.txt"] - - result = await oci.oci_list_bucket_objects(auth_profile="DEFAULT", bucket_name="my-bucket") - - assert result == ["file1.pdf", "file2.txt"] - mock_get_objects.assert_called_once_with("my-bucket", config) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_bucket_objects") - async def test_oci_list_bucket_objects_raises_on_oci_exception( - self, mock_get_objects, mock_oci_get, make_oci_config - ): - """oci_list_bucket_objects should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_get_objects.side_effect = OciException(status_code=403, detail="Access denied") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_list_bucket_objects(auth_profile="DEFAULT", bucket_name="my-bucket") - - assert exc_info.value.status_code == 403 - - -class TestOciProfileUpdate: - """Tests for the oci_profile_update endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_namespace") - async def test_oci_profile_update_success(self, mock_get_namespace, mock_oci_get, make_oci_config): - """oci_profile_update should update and return config.""" - config = make_oci_config(auth_profile="DEFAULT") - mock_oci_get.return_value = config - mock_get_namespace.return_value = "test-namespace" - - payload = make_oci_config(auth_profile="DEFAULT", genai_region="us-phoenix-1") - - result = await oci.oci_profile_update(auth_profile="DEFAULT", payload=payload) - - assert result.namespace == "test-namespace" - assert result.genai_region == "us-phoenix-1" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_oci.get_namespace") - async def test_oci_profile_update_raises_on_oci_exception(self, mock_get_namespace, mock_oci_get, make_oci_config): - """oci_profile_update should raise HTTPException on OciException.""" - config = make_oci_config() - mock_oci_get.return_value = config - mock_get_namespace.side_effect = OciException(status_code=401, detail="Invalid credentials") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_profile_update(auth_profile="DEFAULT", payload=make_oci_config()) - - assert exc_info.value.status_code == 401 - assert config.namespace is None - - -class TestOciDownloadObjects: - """Tests for the oci_download_objects endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_embed.get_temp_directory") - @patch("server.api.v1.oci.utils_oci.get_object") - async def test_oci_download_objects_success( - self, mock_get_object, mock_get_temp_dir, mock_oci_get, make_oci_config, tmp_path - ): - """oci_download_objects should download files and return list.""" - config = 
make_oci_config() - mock_oci_get.return_value = config - mock_get_temp_dir.return_value = tmp_path - - # Create test files - (tmp_path / "file1.pdf").touch() - (tmp_path / "file2.txt").touch() - - result = await oci.oci_download_objects( - bucket_name="my-bucket", - auth_profile="DEFAULT", - request=["file1.pdf", "file2.txt"], - client="test_client", - ) - - assert result.status_code == 200 - assert mock_get_object.call_count == 2 - - -class TestOciCreateGenaiModels: - """Tests for the oci_create_genai_models endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_models.create_genai") - async def test_oci_create_genai_models_success(self, mock_create_genai, mock_oci_get, make_oci_config, make_model): - """oci_create_genai_models should create and return models.""" - config = make_oci_config() - mock_oci_get.return_value = config - models_list = [make_model(model_id="cohere.command", provider="oci")] - mock_create_genai.return_value = models_list - - result = await oci.oci_create_genai_models(auth_profile="DEFAULT") - - assert result == models_list - mock_create_genai.assert_called_once_with(config) - - @pytest.mark.asyncio - @patch("server.api.v1.oci.oci_get") - @patch("server.api.v1.oci.utils_models.create_genai") - async def test_oci_create_genai_models_raises_on_oci_exception( - self, mock_create_genai, mock_oci_get, make_oci_config - ): - """oci_create_genai_models should raise HTTPException on OciException.""" - mock_oci_get.return_value = make_oci_config() - mock_create_genai.side_effect = OciException(status_code=500, detail="GenAI service error") - - with pytest.raises(HTTPException) as exc_info: - await oci.oci_create_genai_models(auth_profile="DEFAULT") - - assert exc_info.value.status_code == 500 - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(oci, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in oci.auth.routes] - - assert "" in routes - assert "/{auth_profile}" in routes - assert "/regions/{auth_profile}" in routes - assert "/genai/{auth_profile}" in routes - assert "/compartments/{auth_profile}" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(oci, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert oci.logger.name == "endpoints.v1.oci" diff --git a/test/unit/server/api/v1/test_v1_probes.py b/test/unit/server/api/v1/test_v1_probes.py deleted file mode 100644 index e716a5ff..00000000 --- a/test/unit/server/api/v1/test_v1_probes.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/probes.py -Tests for Kubernetes health probe endpoints. 
-""" - -import asyncio -from unittest.mock import MagicMock - -import pytest - -from server.api.v1 import probes - - -class TestGetMcp: - """Tests for the get_mcp dependency function.""" - - def test_get_mcp_returns_fastmcp_app(self): - """get_mcp should return the FastMCP app from request state.""" - mock_request = MagicMock() - mock_fastmcp = MagicMock() - mock_request.app.state.fastmcp_app = mock_fastmcp - - result = probes.get_mcp(mock_request) - - assert result == mock_fastmcp - - def test_get_mcp_accesses_correct_state_attribute(self): - """get_mcp should access app.state.fastmcp_app.""" - mock_request = MagicMock() - - probes.get_mcp(mock_request) - - _ = mock_request.app.state.fastmcp_app # Verify attribute access - - -class TestLivenessProbe: - """Tests for the liveness_probe endpoint.""" - - @pytest.mark.asyncio - async def test_liveness_probe_returns_alive(self): - """liveness_probe should return alive status.""" - result = await probes.liveness_probe() - - assert result == {"status": "alive"} - - @pytest.mark.asyncio - async def test_liveness_probe_is_async(self): - """liveness_probe should be an async function.""" - assert asyncio.iscoroutinefunction(probes.liveness_probe) - - -class TestReadinessProbe: - """Tests for the readiness_probe endpoint.""" - - @pytest.mark.asyncio - async def test_readiness_probe_returns_ready(self): - """readiness_probe should return ready status.""" - result = await probes.readiness_probe() - - assert result == {"status": "ready"} - - @pytest.mark.asyncio - async def test_readiness_probe_is_async(self): - """readiness_probe should be an async function.""" - assert asyncio.iscoroutinefunction(probes.readiness_probe) - - -class TestMcpHealthz: - """Tests for the mcp_healthz endpoint.""" - - def test_mcp_healthz_returns_ready_status(self): - """mcp_healthz should return ready status with server info.""" - mock_fastmcp = MagicMock() - mock_fastmcp.__dict__["_mcp_server"] = MagicMock() - mock_fastmcp.__dict__["_mcp_server"].__dict__ = { - "name": "test-server", - "version": "1.0.0", - } - mock_fastmcp.available_tools = ["tool1", "tool2"] - - result = probes.mcp_healthz(mock_fastmcp) - - assert result["status"] == "ready" - assert result["name"] == "test-server" - assert result["version"] == "1.0.0" - assert result["available_tools"] == 2 - - def test_mcp_healthz_returns_not_ready_when_none(self): - """mcp_healthz should return not ready when mcp_engine is None.""" - result = probes.mcp_healthz(None) - - assert result["status"] == "not ready" - - def test_mcp_healthz_with_no_available_tools(self): - """mcp_healthz should handle missing available_tools attribute.""" - mock_fastmcp = MagicMock(spec=[]) # No available_tools attribute - mock_fastmcp.__dict__["_mcp_server"] = MagicMock() - mock_fastmcp.__dict__["_mcp_server"].__dict__ = { - "name": "test-server", - "version": "1.0.0", - } - - result = probes.mcp_healthz(mock_fastmcp) - - assert result["status"] == "ready" - assert result["available_tools"] == 0 - - def test_mcp_healthz_is_not_async(self): - """mcp_healthz should be a sync function.""" - assert not asyncio.iscoroutinefunction(probes.mcp_healthz) - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_noauth_router_exists(self): - """The noauth router should be defined.""" - assert hasattr(probes, "noauth") - - def test_noauth_router_has_routes(self): - """The noauth router should have registered routes.""" - routes = [route.path for route in probes.noauth.routes] - - assert "/liveness" in routes - assert 
"/readiness" in routes - assert "/mcp/healthz" in routes diff --git a/test/unit/server/api/v1/test_v1_settings.py b/test/unit/server/api/v1/test_v1_settings.py deleted file mode 100644 index 348613a1..00000000 --- a/test/unit/server/api/v1/test_v1_settings.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/api/v1/settings.py -Tests for client settings management endpoints. -""" - -from unittest.mock import patch, MagicMock -from io import BytesIO -import json -import pytest -from fastapi import HTTPException, UploadFile -from fastapi.responses import JSONResponse - -from server.api.v1 import settings - - -class TestSettingsGet: - """Tests for the settings_get endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.get_client") - async def test_settings_get_returns_client_settings(self, mock_get_client, make_settings): - """settings_get should return client settings.""" - client_settings = make_settings(client="test_client") - mock_get_client.return_value = client_settings - - mock_request = MagicMock() - - result = await settings.settings_get( - request=mock_request, client="test_client", full_config=False, incl_sensitive=False, incl_readonly=False - ) - - assert result == client_settings - mock_get_client.assert_called_once_with("test_client") - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.get_client") - async def test_settings_get_raises_404_when_not_found(self, mock_get_client): - """settings_get should raise 404 when client not found.""" - mock_get_client.side_effect = ValueError("Client not found") - - mock_request = MagicMock() - - with pytest.raises(HTTPException) as exc_info: - await settings.settings_get( - request=mock_request, - client="nonexistent", - full_config=False, - incl_sensitive=False, - incl_readonly=False, - ) - - assert exc_info.value.status_code == 404 - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.get_client") - @patch("server.api.v1.settings.utils_settings.get_server") - async def test_settings_get_full_config(self, mock_get_server, mock_get_client, make_settings, mock_fastmcp): - """settings_get should return full config when requested.""" - client_settings = make_settings(client="test_client") - mock_get_client.return_value = client_settings - mock_get_server.return_value = { - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - } - - mock_request = MagicMock() - mock_request.app.state.fastmcp_app = mock_fastmcp - - result = await settings.settings_get( - request=mock_request, client="test_client", full_config=True, incl_sensitive=False, incl_readonly=False - ) - - assert isinstance(result, JSONResponse) - mock_get_server.assert_called_once() - - -class TestSettingsUpdate: - """Tests for the settings_update endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.update_client") - async def test_settings_update_success(self, mock_update_client, make_settings): - """settings_update should update and return settings.""" - updated_settings = make_settings(client="test_client", temperature=0.9) - mock_update_client.return_value = updated_settings - - payload = make_settings(client="test_client", temperature=0.9) - - result = await settings.settings_update(payload=payload, client="test_client") - - assert result == updated_settings - 
mock_update_client.assert_called_once_with(payload, "test_client") - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.update_client") - async def test_settings_update_raises_404_when_not_found(self, mock_update_client, make_settings): - """settings_update should raise 404 when client not found.""" - mock_update_client.side_effect = ValueError("Client not found") - - payload = make_settings(client="nonexistent") - - with pytest.raises(HTTPException) as exc_info: - await settings.settings_update(payload=payload, client="nonexistent") - - assert exc_info.value.status_code == 404 - - -class TestSettingsCreate: - """Tests for the settings_create endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - async def test_settings_create_success(self, mock_create_client, make_settings): - """settings_create should create and return new settings.""" - new_settings = make_settings(client="new_client") - mock_create_client.return_value = new_settings - - result = await settings.settings_create(client="new_client") - - assert result == new_settings - mock_create_client.assert_called_once_with("new_client") - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - async def test_settings_create_raises_409_when_exists(self, mock_create_client): - """settings_create should raise 409 when client already exists.""" - mock_create_client.side_effect = ValueError("Client already exists") - - with pytest.raises(HTTPException) as exc_info: - await settings.settings_create(client="existing_client") - - assert exc_info.value.status_code == 409 - - -class TestLoadSettingsFromFile: - """Tests for the load_settings_from_file endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_file_success(self, mock_load_config, mock_create_client): - """load_settings_from_file should load config from JSON file.""" - mock_create_client.return_value = MagicMock() - mock_load_config.return_value = None - - config_data = {"client_settings": {"client": "test"}, "database_configs": []} - file_content = json.dumps(config_data).encode() - mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") - - result = await settings.load_settings_from_file(client="test_client", file=mock_file) - - assert result["message"] == "Configuration loaded successfully." - mock_load_config.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - async def test_load_settings_from_file_wrong_extension(self, mock_create_client): - """load_settings_from_file should raise error for non-JSON files. - - Note: Due to the generic exception handler in the source code, - HTTPException(400) is caught and wrapped in HTTPException(500). 
- """ - mock_create_client.return_value = MagicMock() - - mock_file = UploadFile(file=BytesIO(b"data"), filename="config.txt") - - with pytest.raises(HTTPException) as exc_info: - await settings.load_settings_from_file(client="test_client", file=mock_file) - - # The 400 HTTPException gets caught by generic exception handler and wrapped in 500 - assert exc_info.value.status_code == 500 - assert "JSON" in str(exc_info.value.detail) - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - async def test_load_settings_from_file_invalid_json(self, mock_create_client): - """load_settings_from_file should raise 400 for invalid JSON.""" - mock_create_client.return_value = MagicMock() - - mock_file = UploadFile(file=BytesIO(b"not valid json"), filename="config.json") - - with pytest.raises(HTTPException) as exc_info: - await settings.load_settings_from_file(client="test_client", file=mock_file) - - assert exc_info.value.status_code == 400 - assert "Invalid JSON" in str(exc_info.value.detail) - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_file_key_error(self, mock_load_config, mock_create_client): - """load_settings_from_file should raise 400 on KeyError.""" - mock_create_client.return_value = MagicMock() - mock_load_config.side_effect = KeyError("Missing required key") - - config_data = {"incomplete": "data"} - file_content = json.dumps(config_data).encode() - mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") - - with pytest.raises(HTTPException) as exc_info: - await settings.load_settings_from_file(client="test_client", file=mock_file) - - assert exc_info.value.status_code == 400 - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_file_handles_existing_client(self, mock_load_config, mock_create_client): - """load_settings_from_file should continue if client already exists.""" - mock_create_client.side_effect = ValueError("Client already exists") - mock_load_config.return_value = None - - config_data = {"client_settings": {"client": "test"}} - file_content = json.dumps(config_data).encode() - mock_file = UploadFile(file=BytesIO(file_content), filename="config.json") - - result = await settings.load_settings_from_file(client="test_client", file=mock_file) - - assert result["message"] == "Configuration loaded successfully." - - -class TestLoadSettingsFromJson: - """Tests for the load_settings_from_json endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_json_success(self, mock_load_config, mock_create_client, make_configuration): - """load_settings_from_json should load config from JSON payload.""" - mock_create_client.return_value = MagicMock() - mock_load_config.return_value = None - - payload = make_configuration(client="test_client") - - result = await settings.load_settings_from_json(client="test_client", payload=payload) - - assert result["message"] == "Configuration loaded successfully." 
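-        # The success message alone is not enough; also verify the loader was invoked exactly once.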
- mock_load_config.assert_called_once() - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_json_key_error(self, mock_load_config, mock_create_client, make_configuration): - """load_settings_from_json should raise 400 on KeyError.""" - mock_create_client.return_value = MagicMock() - mock_load_config.side_effect = KeyError("Missing required key") - - payload = make_configuration(client="test_client") - - with pytest.raises(HTTPException) as exc_info: - await settings.load_settings_from_json(client="test_client", payload=payload) - - assert exc_info.value.status_code == 400 - - @pytest.mark.asyncio - @patch("server.api.v1.settings.utils_settings.create_client") - @patch("server.api.v1.settings.utils_settings.load_config_from_json_data") - async def test_load_settings_from_json_handles_existing_client( - self, mock_load_config, mock_create_client, make_configuration - ): - """load_settings_from_json should continue if client already exists.""" - mock_create_client.side_effect = ValueError("Client already exists") - mock_load_config.return_value = None - - payload = make_configuration(client="test_client") - - result = await settings.load_settings_from_json(client="test_client", payload=payload) - - assert result["message"] == "Configuration loaded successfully." - - -class TestIncludeParams: # pylint: disable=protected-access - """Tests for the include parameter dependencies.""" - - def test_incl_sensitive_param_default(self): - """_incl_sensitive_param should default to False.""" - result = settings._incl_sensitive_param(incl_sensitive=False) - assert result is False - - def test_incl_sensitive_param_true(self): - """_incl_sensitive_param should return True when set.""" - result = settings._incl_sensitive_param(incl_sensitive=True) - assert result is True - - def test_incl_readonly_param_default(self): - """_incl_readonly_param should default to False.""" - result = settings._incl_readonly_param(incl_readonly=False) - assert result is False - - def test_incl_readonly_param_true(self): - """_incl_readonly_param should return True when set.""" - result = settings._incl_readonly_param(incl_readonly=True) - assert result is True - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(settings, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered routes.""" - routes = [route.path for route in settings.auth.routes] - - assert "" in routes # Get, Update, Create - assert "/load/file" in routes - assert "/load/json" in routes - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(settings, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert settings.logger.name == "endpoints.v1.settings" diff --git a/test/unit/server/api/v1/test_v1_testbed.py b/test/unit/server/api/v1/test_v1_testbed.py deleted file mode 100644 index 8ba6d12c..00000000 --- a/test/unit/server/api/v1/test_v1_testbed.py +++ /dev/null @@ -1,623 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
- -Unit tests for server/api/v1/testbed.py -Tests for Q&A testbed and evaluation endpoints. -""" -# pylint: disable=protected-access,too-few-public-methods,too-many-arguments -# pylint: disable=too-many-positional-arguments,too-many-locals - -from unittest.mock import patch, MagicMock, AsyncMock -from io import BytesIO -import pytest -from fastapi import HTTPException, UploadFile -import litellm - -from server.api.v1 import testbed -from common.schema import TestSets, TestSetQA, Evaluation, EvaluationReport - - -class TestTestbedTestsets: - """Tests for the testbed_testsets endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_testsets") - async def test_testbed_testsets_returns_list( - self, mock_get_testsets, mock_get_db, mock_db_connection - ): - """testbed_testsets should return list of testsets.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_testsets = [ - TestSets(tid="TS001", name="Test Set 1", created="2024-01-01"), - TestSets(tid="TS002", name="Test Set 2", created="2024-01-02"), - ] - mock_get_testsets.return_value = mock_testsets - - result = await testbed.testbed_testsets(client="test_client") - - assert result == mock_testsets - mock_get_testsets.assert_called_once_with(db_conn=mock_db_connection) - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_testsets") - async def test_testbed_testsets_empty_list( - self, mock_get_testsets, mock_get_db, mock_db_connection - ): - """testbed_testsets should return empty list when no testsets.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - mock_get_testsets.return_value = [] - - result = await testbed.testbed_testsets(client="test_client") - - assert result == [] - - -class TestTestbedEvaluations: - """Tests for the testbed_evaluations endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_evaluations") - async def test_testbed_evaluations_returns_list( - self, mock_get_evals, mock_get_db, mock_db_connection - ): - """testbed_evaluations should return list of evaluations.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_evals = [ - Evaluation(eid="EV001", evaluated="2024-01-01", correctness=0.85), - Evaluation(eid="EV002", evaluated="2024-01-02", correctness=0.90), - ] - mock_get_evals.return_value = mock_evals - - result = await testbed.testbed_evaluations(tid="ts001", client="test_client") - - assert result == mock_evals - mock_get_evals.assert_called_once_with(db_conn=mock_db_connection, tid="TS001") - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_evaluations") - async def test_testbed_evaluations_uppercases_tid( - self, mock_get_evals, mock_get_db, mock_db_connection - ): - """testbed_evaluations should uppercase the tid.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - mock_get_evals.return_value = [] - - await testbed.testbed_evaluations(tid="lowercase", client="test_client") - - mock_get_evals.assert_called_once_with(db_conn=mock_db_connection, tid="LOWERCASE") - - -class TestTestbedEvaluation: 
- """Tests for the testbed_evaluation endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.process_report") - async def test_testbed_evaluation_returns_report( - self, mock_process_report, mock_get_db, mock_db_connection - ): - """testbed_evaluation should return evaluation report.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_report = MagicMock(spec=EvaluationReport) - mock_process_report.return_value = mock_report - - result = await testbed.testbed_evaluation(eid="ev001", client="test_client") - - assert result == mock_report - mock_process_report.assert_called_once_with(db_conn=mock_db_connection, eid="EV001") - - -class TestTestbedTestsetQa: - """Tests for the testbed_testset_qa endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") - async def test_testbed_testset_qa_returns_data( - self, mock_get_qa, mock_get_db, mock_db_connection - ): - """testbed_testset_qa should return Q&A data.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_qa = TestSetQA(qa_data=[{"question": "Q1", "answer": "A1"}]) - mock_get_qa.return_value = mock_qa - - result = await testbed.testbed_testset_qa(tid="ts001", client="test_client") - - assert result == mock_qa - mock_get_qa.assert_called_once_with(db_conn=mock_db_connection, tid="TS001") - - -class TestTestbedDeleteTestset: - """Tests for the testbed_delete_testset endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.delete_qa") - async def test_testbed_delete_testset_success( - self, mock_delete_qa, mock_get_db, mock_db_connection - ): - """testbed_delete_testset should delete and return success.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - mock_delete_qa.return_value = None - - result = await testbed.testbed_delete_testset(tid="ts001", client="test_client") - - assert result.status_code == 200 - mock_delete_qa.assert_called_once_with(mock_db_connection, "TS001") - - -class TestTestbedUpsertTestsets: - """Tests for the testbed_upsert_testsets endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.jsonl_to_json_content") - @patch("server.api.v1.testbed.utils_testbed.upsert_qa") - @patch("server.api.v1.testbed.testbed_testset_qa") - async def test_testbed_upsert_testsets_success( - self, mock_testset_qa, mock_upsert, mock_jsonl, mock_get_db, mock_db_connection - ): - """testbed_upsert_testsets should upload and return Q&A.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - mock_jsonl.return_value = [{"question": "Q1", "answer": "A1"}] - mock_upsert.return_value = "TS001" - mock_testset_qa.return_value = TestSetQA(qa_data=[{"question": "Q1"}]) - - mock_file = UploadFile(file=BytesIO(b'{"question": "Q1"}'), filename="test.jsonl") - - result = await testbed.testbed_upsert_testsets( - files=[mock_file], name="Test Set", tid=None, client="test_client" - ) - - assert isinstance(result, TestSetQA) - mock_db_connection.commit.assert_called_once() - - @pytest.mark.asyncio - 
@patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.jsonl_to_json_content") - async def test_testbed_upsert_testsets_handles_exception( - self, mock_jsonl, mock_get_db, mock_db_connection - ): - """testbed_upsert_testsets should raise 500 on exception.""" - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - mock_jsonl.side_effect = Exception("Parse error") - - mock_file = UploadFile(file=BytesIO(b"invalid"), filename="test.jsonl") - - with pytest.raises(HTTPException) as exc_info: - await testbed.testbed_upsert_testsets( - files=[mock_file], name="Test", tid=None, client="test_client" - ) - - assert exc_info.value.status_code == 500 - - -class TestHandleTestsetError: - """Tests for the _handle_testset_error helper function.""" - - def test_handle_testset_error_key_error_columns(self, tmp_path): - """_handle_testset_error should raise 400 for column KeyError.""" - ex = KeyError("None of ['col1'] are in the columns") - - with pytest.raises(HTTPException) as exc_info: - testbed._handle_testset_error(ex, tmp_path, "test-model") - - assert exc_info.value.status_code == 400 - assert "test-model" in str(exc_info.value.detail) - - def test_handle_testset_error_value_error(self, tmp_path): - """_handle_testset_error should raise 400 for ValueError.""" - ex = ValueError("Invalid value") - - with pytest.raises(HTTPException) as exc_info: - testbed._handle_testset_error(ex, tmp_path, "test-model") - - assert exc_info.value.status_code == 400 - - def test_handle_testset_error_api_connection_error(self, tmp_path): - """_handle_testset_error should raise 424 for API connection error.""" - ex = litellm.APIConnectionError( - message="Connection failed", llm_provider="openai", model="gpt-4" - ) - - with pytest.raises(HTTPException) as exc_info: - testbed._handle_testset_error(ex, tmp_path, "test-model") - - assert exc_info.value.status_code == 424 - - def test_handle_testset_error_unknown_exception(self, tmp_path): - """_handle_testset_error should raise 500 for unknown exceptions.""" - ex = RuntimeError("Unknown error") - - with pytest.raises(HTTPException) as exc_info: - testbed._handle_testset_error(ex, tmp_path, "test-model") - - assert exc_info.value.status_code == 500 - - def test_handle_testset_error_other_key_error(self, tmp_path): - """_handle_testset_error should re-raise other KeyErrors.""" - ex = KeyError("some_other_key") - - with pytest.raises(KeyError): - testbed._handle_testset_error(ex, tmp_path, "test-model") - - -class TestTestbedGenerateQa: - """Tests for the testbed_generate_qa endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_oci.get") - async def test_testbed_generate_qa_raises_400_on_value_error(self, mock_oci_get): - """testbed_generate_qa should raise 400 on ValueError.""" - mock_oci_get.side_effect = ValueError("Invalid OCI config") - - mock_file = UploadFile(file=BytesIO(b"content"), filename="test.txt") - - with pytest.raises(HTTPException) as exc_info: - await testbed.testbed_generate_qa( - files=[mock_file], - name="Test", - ll_model="gpt-4", - embed_model="text-embedding-3", - questions=2, - client="test_client", - ) - - assert exc_info.value.status_code == 400 - - -class TestRouterConfiguration: - """Tests for router configuration.""" - - def test_auth_router_exists(self): - """The auth router should be defined.""" - assert hasattr(testbed, "auth") - - def test_auth_router_has_routes(self): - """The auth router should have registered 
routes.""" - routes = [route.path for route in testbed.auth.routes] - - assert "/testsets" in routes - assert "/evaluations" in routes - assert "/evaluation" in routes - assert "/testset_qa" in routes - assert "/testset_delete/{tid}" in routes - assert "/testset_load" in routes - assert "/testset_generate" in routes - assert "/evaluate" in routes - - -class TestProcessFileForTestset: - """Tests for the _process_file_for_testset helper function.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_testbed.load_and_split") - @patch("server.api.v1.testbed.utils_testbed.build_knowledge_base") - async def test_process_file_writes_and_processes( - self, mock_build_kb, mock_load_split, tmp_path - ): - """_process_file_for_testset should write file and build knowledge base.""" - mock_load_split.return_value = ["node1", "node2"] - mock_testset = MagicMock() - - # Make save create an actual file (function reads it after save) - def save_side_effect(path): - with open(path, "w", encoding="utf-8") as f: - f.write('{"question": "generated"}\n') - - mock_testset.save = save_side_effect - mock_build_kb.return_value = mock_testset - - mock_file = MagicMock() - mock_file.read = AsyncMock(return_value=b"file content") - mock_file.filename = "test.pdf" - - full_testsets = tmp_path / "all_testsets.jsonl" - full_testsets.touch() - - await testbed._process_file_for_testset( - file=mock_file, - temp_directory=tmp_path, - full_testsets=full_testsets, - name="TestSet", - questions=5, - ll_model="gpt-4", - embed_model="text-embedding-3", - oci_config=MagicMock(), - ) - - mock_load_split.assert_called_once() - mock_build_kb.assert_called_once() - # Verify file was created (save was called) - assert (tmp_path / "TestSet.jsonl").exists() - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_testbed.load_and_split") - @patch("server.api.v1.testbed.utils_testbed.build_knowledge_base") - async def test_process_file_appends_to_full_testsets( - self, mock_build_kb, mock_load_split, tmp_path - ): - """_process_file_for_testset should append to full_testsets file.""" - mock_load_split.return_value = ["node1"] - mock_testset = MagicMock() - - def save_side_effect(path): - with open(path, "w", encoding="utf-8") as f: - f.write('{"question": "Q1"}\n') - - mock_testset.save = save_side_effect - mock_build_kb.return_value = mock_testset - - mock_file = MagicMock() - mock_file.read = AsyncMock(return_value=b"content") - mock_file.filename = "test.pdf" - - full_testsets = tmp_path / "all_testsets.jsonl" - full_testsets.write_text('{"question": "existing"}\n') - - await testbed._process_file_for_testset( - file=mock_file, - temp_directory=tmp_path, - full_testsets=full_testsets, - name="TestSet", - questions=2, - ll_model="gpt-4", - embed_model="embed", - oci_config=MagicMock(), - ) - - content = full_testsets.read_text() - assert '{"question": "existing"}' in content - assert '{"question": "Q1"}' in content - - -class TestCollectTestbedAnswers: - """Tests for the _collect_testbed_answers helper function.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.chat.chat_post") - async def test_collect_answers_returns_agent_answers(self, mock_chat_post): - """_collect_testbed_answers should return list of AgentAnswer objects.""" - mock_chat_post.return_value = { - "choices": [{"message": {"content": "Test response"}}] - } - - mock_df = MagicMock() - mock_df.itertuples.return_value = [ - MagicMock(question="Question 1"), - MagicMock(question="Question 2"), - ] - mock_testset = MagicMock() - 
mock_testset.to_pandas.return_value = mock_df - - result = await testbed._collect_testbed_answers(mock_testset, "test_client") - - assert len(result) == 2 - assert result[0].message == "Test response" - assert result[1].message == "Test response" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.chat.chat_post") - async def test_collect_answers_calls_chat_for_each_question(self, mock_chat_post): - """_collect_testbed_answers should call chat endpoint for each question.""" - mock_chat_post.return_value = { - "choices": [{"message": {"content": "Response"}}] - } - - mock_df = MagicMock() - mock_df.itertuples.return_value = [ - MagicMock(question="Q1"), - MagicMock(question="Q2"), - MagicMock(question="Q3"), - ] - mock_testset = MagicMock() - mock_testset.to_pandas.return_value = mock_df - - await testbed._collect_testbed_answers(mock_testset, "client123") - - assert mock_chat_post.call_count == 3 - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.chat.chat_post") - async def test_collect_answers_empty_testset(self, mock_chat_post): - """_collect_testbed_answers should return empty list for empty testset.""" - mock_df = MagicMock() - mock_df.itertuples.return_value = [] - mock_testset = MagicMock() - mock_testset.to_pandas.return_value = mock_df - - result = await testbed._collect_testbed_answers(mock_testset, "client") - - assert result == [] - mock_chat_post.assert_not_called() - - -class TestTestbedEvaluate: - """Tests for the testbed_evaluate endpoint.""" - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.pickle.dumps") - @patch("server.api.v1.testbed.utils_settings.get_client") - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") - @patch("server.api.v1.testbed.utils_embed.get_temp_directory") - @patch("server.api.v1.testbed.QATestset.load") - @patch("server.api.v1.testbed.utils_oci.get") - @patch("server.api.v1.testbed.utils_models.get_litellm_config") - @patch("server.api.v1.testbed.set_llm_model") - @patch("server.api.v1.testbed.get_prompt_with_override") - @patch("server.api.v1.testbed._collect_testbed_answers") - @patch("server.api.v1.testbed.evaluate") - @patch("server.api.v1.testbed.utils_testbed.insert_evaluation") - @patch("server.api.v1.testbed.utils_testbed.process_report") - @patch("server.api.v1.testbed.shutil.rmtree") - async def test_testbed_evaluate_success( - self, - _mock_rmtree, - mock_process_report, - mock_insert_eval, - mock_evaluate, - mock_collect_answers, - mock_get_prompt, - _mock_set_llm, - mock_get_litellm, - mock_oci_get, - mock_qa_load, - mock_get_temp_dir, - mock_get_testset_qa, - mock_get_db, - mock_get_settings, - mock_pickle_dumps, - mock_db_connection, - tmp_path, - ): - """testbed_evaluate should run evaluation and return report.""" - mock_pickle_dumps.return_value = b"pickled_report" - - mock_settings = MagicMock() - mock_settings.ll_model = MagicMock() - mock_settings.vector_search = MagicMock() - mock_settings.model_dump_json.return_value = "{}" - mock_get_settings.return_value = mock_settings - - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_get_testset_qa.return_value = MagicMock(qa_data=[{"q": "Q1", "a": "A1"}]) - mock_get_temp_dir.return_value = tmp_path - - mock_loaded_testset = MagicMock() - mock_qa_load.return_value = mock_loaded_testset - - mock_oci_get.return_value = MagicMock() - mock_get_litellm.return_value = {"api_key": "test"} - - mock_prompt_msg = MagicMock() - 
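-        # The prompt override is mocked so that .content.text yields the judge
-        # system prompt consumed during evaluation.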
mock_prompt_msg.content.text = "You are a judge." - mock_get_prompt.return_value = mock_prompt_msg - - mock_collect_answers.return_value = [MagicMock(message="Answer")] - - mock_report = MagicMock() - mock_report.correctness = 0.85 - mock_evaluate.return_value = mock_report - - mock_insert_eval.return_value = "EID123" - - mock_eval_report = MagicMock() - mock_process_report.return_value = mock_eval_report - - result = await testbed.testbed_evaluate( - tid="TS001", - judge="gpt-4", - client="test_client", - ) - - assert result == mock_eval_report - mock_settings.ll_model.chat_history = False - mock_settings.vector_search.grade = False - mock_evaluate.assert_called_once() - mock_insert_eval.assert_called_once() - mock_db_connection.commit.assert_called() - - @pytest.mark.asyncio - @patch("server.api.v1.testbed.utils_settings.get_client") - @patch("server.api.v1.testbed.utils_databases.get_client_database") - @patch("server.api.v1.testbed.utils_testbed.get_testset_qa") - @patch("server.api.v1.testbed.utils_embed.get_temp_directory") - @patch("server.api.v1.testbed.QATestset.load") - @patch("server.api.v1.testbed.utils_oci.get") - @patch("server.api.v1.testbed.utils_models.get_litellm_config") - @patch("server.api.v1.testbed.set_llm_model") - @patch("server.api.v1.testbed.get_prompt_with_override") - @patch("server.api.v1.testbed._collect_testbed_answers") - @patch("server.api.v1.testbed.evaluate") - async def test_testbed_evaluate_raises_500_on_correctness_key_error( - self, - mock_evaluate, - mock_collect_answers, - mock_get_prompt, - _mock_set_llm, - mock_get_litellm, - mock_oci_get, - mock_qa_load, - mock_get_temp_dir, - mock_get_testset_qa, - mock_get_db, - mock_get_settings, - mock_db_connection, - tmp_path, - ): - """testbed_evaluate should raise 500 when correctness key is missing.""" - mock_settings = MagicMock() - mock_settings.ll_model = MagicMock() - mock_settings.vector_search = MagicMock() - mock_get_settings.return_value = mock_settings - - mock_db = MagicMock() - mock_db.connection = mock_db_connection - mock_get_db.return_value = mock_db - - mock_get_testset_qa.return_value = MagicMock(qa_data=[{"q": "Q1"}]) - mock_get_temp_dir.return_value = tmp_path - - mock_qa_load.return_value = MagicMock() - mock_oci_get.return_value = MagicMock() - mock_get_litellm.return_value = {} - - mock_prompt_msg = MagicMock() - mock_prompt_msg.content.text = "Judge prompt" - mock_get_prompt.return_value = mock_prompt_msg - - mock_collect_answers.return_value = [] - mock_evaluate.side_effect = KeyError("correctness") - - with pytest.raises(HTTPException) as exc_info: - await testbed.testbed_evaluate( - tid="TS001", - judge="gpt-4", - client="test_client", - ) - - assert exc_info.value.status_code == 500 - assert "correctness" in str(exc_info.value.detail) - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(testbed, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert testbed.logger.name == "endpoints.v1.testbed" diff --git a/test/unit/server/bootstrap/__init__.py b/test/unit/server/bootstrap/__init__.py deleted file mode 100644 index 170366b5..00000000 --- a/test/unit/server/bootstrap/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Bootstrap unit test package diff --git a/test/unit/server/bootstrap/conftest.py b/test/unit/server/bootstrap/conftest.py deleted file mode 100644 index 9bc23743..00000000 --- a/test/unit/server/bootstrap/conftest.py +++ 
/dev/null @@ -1,57 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Pytest fixtures for server/bootstrap unit tests. - -Re-exports shared fixtures from test.shared_fixtures and adds unit-test specific fixtures. -""" - -# pylint: disable=redefined-outer-name unused-import - -from unittest.mock import MagicMock, patch - -# Re-export shared fixtures for pytest discovery -from test.shared_fixtures import ( - make_database, - make_model, - make_oci_config, - make_ll_settings, - make_settings, - make_configuration, - temp_config_file, - reset_config_store, - clean_env, -) - -import pytest - - -################################################# -# Unit Test Specific Mock Fixtures -################################################# - - -@pytest.fixture -def mock_oci_config_parser(): - """Mock OCI config parser for testing OCI bootstrap.""" - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = [] - mock_parser.return_value = mock_instance - yield mock_parser - - -@pytest.fixture -def mock_oci_config_from_file(): - """Mock oci.config.from_file for testing OCI bootstrap.""" - with patch("oci.config.from_file") as mock_from_file: - yield mock_from_file - - -@pytest.fixture -def mock_is_url_accessible(): - """Mock is_url_accessible for testing model bootstrap.""" - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - yield mock_accessible diff --git a/test/unit/server/bootstrap/test_bootstrap_bootstrap.py b/test/unit/server/bootstrap/test_bootstrap_bootstrap.py deleted file mode 100644 index 542baee1..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_bootstrap.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/bootstrap.py -Tests for the main bootstrap module that coordinates all bootstrap operations. 
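-
-Each test patches the four bootstrap main() functions and reloads the module,
-so the module-level bootstrap code runs against mocks only. A minimal sketch
-of that reload pattern (flattened here into a single with-statement):
-
-    with patch("server.bootstrap.databases.main", return_value=[]), \
-         patch("server.bootstrap.models.main", return_value=[]), \
-         patch("server.bootstrap.oci.main", return_value=[]), \
-         patch("server.bootstrap.settings.main", return_value=[]):
-        importlib.reload(bootstrap)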
-""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods -# pylint: disable=import-outside-toplevel - -import importlib -from unittest.mock import patch - -from server.bootstrap import bootstrap - - -class TestBootstrapModule: - """Tests for the bootstrap module initialization.""" - - def test_database_objects_is_list(self): - """DATABASE_OBJECTS should be a list.""" - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - # Reload to trigger module-level code with mocks - importlib.reload(bootstrap) - - assert isinstance(bootstrap.DATABASE_OBJECTS, list) - - def test_model_objects_is_list(self): - """MODEL_OBJECTS should be a list.""" - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert isinstance(bootstrap.MODEL_OBJECTS, list) - - def test_oci_objects_is_list(self): - """OCI_OBJECTS should be a list.""" - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert isinstance(bootstrap.OCI_OBJECTS, list) - - def test_settings_objects_is_list(self): - """SETTINGS_OBJECTS should be a list.""" - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert isinstance(bootstrap.SETTINGS_OBJECTS, list) - - def test_calls_all_bootstrap_functions(self): - """Bootstrap module should call all main() functions.""" - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - mock_databases.assert_called_once() - mock_models.assert_called_once() - mock_oci.assert_called_once() - mock_settings.assert_called_once() - - def test_stores_database_results(self, make_database): - """Bootstrap module should store database.main() results.""" - db1 = make_database(name="DB1") - db2 = make_database(name="DB2") - - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [db1, db2] - with 
patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert len(bootstrap.DATABASE_OBJECTS) == 2 - assert bootstrap.DATABASE_OBJECTS[0].name == "DB1" - - def test_stores_model_results(self, make_model): - """Bootstrap module should store models.main() results.""" - model1 = make_model(model_id="model1") - model2 = make_model(model_id="model2") - - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [model1, model2] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert len(bootstrap.MODEL_OBJECTS) == 2 - - def test_stores_oci_results(self, make_oci_config): - """Bootstrap module should store oci.main() results.""" - oci1 = make_oci_config(auth_profile="PROFILE1") - oci2 = make_oci_config(auth_profile="PROFILE2") - - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [oci1, oci2] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [] - - importlib.reload(bootstrap) - - assert len(bootstrap.OCI_OBJECTS) == 2 - - def test_stores_settings_results(self, make_settings): - """Bootstrap module should store settings.main() results.""" - settings1 = make_settings(client="client1") - settings2 = make_settings(client="client2") - - with patch("server.bootstrap.databases.main") as mock_databases: - mock_databases.return_value = [] - with patch("server.bootstrap.models.main") as mock_models: - mock_models.return_value = [] - with patch("server.bootstrap.oci.main") as mock_oci: - mock_oci.return_value = [] - with patch("server.bootstrap.settings.main") as mock_settings: - mock_settings.return_value = [settings1, settings2] - - importlib.reload(bootstrap) - - assert len(bootstrap.SETTINGS_OBJECTS) == 2 - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured in bootstrap module.""" - assert hasattr(bootstrap, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert bootstrap.logger.name == "bootstrap" diff --git a/test/unit/server/bootstrap/test_bootstrap_configfile.py b/test/unit/server/bootstrap/test_bootstrap_configfile.py deleted file mode 100644 index 12c505ca..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_configfile.py +++ /dev/null @@ -1,229 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/configfile.py -Tests for ConfigStore class and config_file_path function. 
-""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import json -import os -import tempfile -from pathlib import Path -from threading import Thread, Barrier - -import pytest - -from server.bootstrap import configfile -from server.bootstrap.configfile import config_file_path - - -class TestConfigStore: - """Tests for the ConfigStore class.""" - - def test_load_from_file_success(self, reset_config_store, temp_config_file, make_settings): - """ConfigStore should load configuration from a valid JSON file.""" - settings = make_settings(client="test_client") - config_path = temp_config_file(client_settings=settings) - - try: - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert config.client_settings.client == "test_client" - finally: - os.unlink(config_path) - - def test_load_from_file_nonexistent_file(self, reset_config_store): - """ConfigStore should handle nonexistent files gracefully.""" - nonexistent_path = Path("/nonexistent/path/config.json") - - reset_config_store.load_from_file(nonexistent_path) - config = reset_config_store.get() - - assert config is None - - def test_load_from_file_wrong_extension_warns(self, reset_config_store, caplog): - """ConfigStore should warn when file has wrong extension.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as temp_file: - # Need valid client_settings with required 'client' field - json.dump( - { - "client_settings": {"client": "test"}, - "database_configs": [], - "model_configs": [], - "oci_configs": [], - "prompt_configs": [], - }, - temp_file, - ) - temp_path = Path(temp_file.name) - - try: - reset_config_store.load_from_file(temp_path) - assert "should be a .json file" in caplog.text - finally: - os.unlink(temp_path) - - def test_load_from_file_only_loads_once(self, reset_config_store, temp_config_file, make_settings): - """ConfigStore should only load configuration once (singleton pattern).""" - settings1 = make_settings(client="first_client") - settings2 = make_settings(client="second_client") - - config_path1 = temp_config_file(client_settings=settings1) - config_path2 = temp_config_file(client_settings=settings2) - - try: - reset_config_store.load_from_file(config_path1) - reset_config_store.load_from_file(config_path2) # Should be ignored - - config = reset_config_store.get() - assert config.client_settings.client == "first_client" - finally: - os.unlink(config_path1) - os.unlink(config_path2) - - def test_load_from_file_thread_safety(self, reset_config_store, temp_config_file, make_settings): - """ConfigStore should handle concurrent loading safely.""" - settings = make_settings(client="thread_test") - config_path = temp_config_file(client_settings=settings) - - num_threads = 5 - barrier = Barrier(num_threads) - results = [] - - def load_config(): - barrier.wait() # Synchronize threads - reset_config_store.load_from_file(config_path) - results.append(reset_config_store.get()) - - try: - threads = [Thread(target=load_config) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - # All threads should see the same config - assert len(results) == num_threads - assert all(r is not None for r in results) - assert all(r.client_settings.client == "thread_test" for r in results) - finally: - os.unlink(config_path) - - def test_load_from_file_with_database_configs( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - 
"""ConfigStore should load database configurations.""" - settings = make_settings() - db = make_database(name="TEST_DB", user="admin") - config_path = temp_config_file(client_settings=settings, database_configs=[db]) - - try: - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert len(config.database_configs) == 1 - assert config.database_configs[0].name == "TEST_DB" - assert config.database_configs[0].user == "admin" - finally: - os.unlink(config_path) - - def test_load_from_file_with_model_configs(self, reset_config_store, temp_config_file, make_settings, make_model): - """ConfigStore should load model configurations.""" - settings = make_settings() - model = make_model(model_id="test-model", provider="openai") - config_path = temp_config_file(client_settings=settings, model_configs=[model]) - - try: - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert len(config.model_configs) == 1 - assert config.model_configs[0].id == "test-model" - finally: - os.unlink(config_path) - - def test_load_from_file_with_oci_configs( - self, reset_config_store, temp_config_file, make_settings, make_oci_config - ): - """ConfigStore should load OCI configurations.""" - settings = make_settings() - oci_config = make_oci_config(auth_profile="TEST_PROFILE") - config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) - - try: - reset_config_store.load_from_file(config_path) - config = reset_config_store.get() - - assert config is not None - assert len(config.oci_configs) == 1 - assert config.oci_configs[0].auth_profile == "TEST_PROFILE" - finally: - os.unlink(config_path) - - def test_load_from_file_invalid_json(self, reset_config_store): - """ConfigStore should raise error for invalid JSON.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as temp_file: - temp_file.write("not valid json {") - temp_path = Path(temp_file.name) - - try: - with pytest.raises(json.JSONDecodeError): - reset_config_store.load_from_file(temp_path) - finally: - os.unlink(temp_path) - - def test_get_returns_none_when_not_loaded(self, reset_config_store): - """ConfigStore.get() should return None when config not loaded.""" - config = reset_config_store.get() - assert config is None - - -class TestConfigFilePath: - """Tests for the config_file_path function.""" - - def test_config_file_path_returns_string(self): - """config_file_path should return a string path.""" - path = config_file_path() - assert isinstance(path, str) - - def test_config_file_path_ends_with_json(self): - """config_file_path should return a .json file path.""" - path = config_file_path() - assert path.endswith(".json") - - def test_config_file_path_contains_etc_directory(self): - """config_file_path should include etc directory.""" - path = config_file_path() - assert "etc" in path - assert "configuration.json" in path - - def test_config_file_path_is_absolute(self): - """config_file_path should return an absolute path.""" - path = config_file_path() - assert os.path.isabs(path) - - def test_config_file_path_parent_is_server_directory(self): - """config_file_path should be relative to server directory.""" - path = config_file_path() - path_obj = Path(path) - # Should be under server/etc/configuration.json - assert path_obj.parent.name == "etc" - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger 
should be configured in configfile module.""" - assert hasattr(configfile, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert configfile.logger.name == "bootstrap.configfile" diff --git a/test/unit/server/bootstrap/test_bootstrap_databases.py b/test/unit/server/bootstrap/test_bootstrap_databases.py deleted file mode 100644 index 3a658689..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_databases.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/databases.py -Tests for database bootstrap functionality. -""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import os - -from test.shared_fixtures import ( - assert_database_list_valid, - assert_has_default_database, - get_database_by_name, -) - -import pytest - -from server.bootstrap import databases as databases_module - - -@pytest.mark.usefixtures("reset_config_store", "clean_env") -class TestDatabasesMain: - """Tests for the databases.main() function.""" - - def test_main_returns_list_of_databases(self): - """main() should return a list of Database objects.""" - result = databases_module.main() - assert_database_list_valid(result) - - def test_main_creates_default_database_when_no_config(self): - """main() should create DEFAULT database when no config is loaded.""" - result = databases_module.main() - assert_has_default_database(result) - - def test_main_uses_env_vars_for_default_database(self): - """main() should use environment variables for DEFAULT database.""" - os.environ["DB_USERNAME"] = "env_user" - os.environ["DB_PASSWORD"] = "env_password" - os.environ["DB_DSN"] = "env_dsn:1521/ENVPDB" - os.environ["TNS_ADMIN"] = "/env/tns_admin" - - try: - db_list = databases_module.main() - default_entry = get_database_by_name(db_list, "DEFAULT") - assert default_entry.user == "env_user" - assert default_entry.password == "env_password" - assert default_entry.dsn == "env_dsn:1521/ENVPDB" - assert default_entry.config_dir == "/env/tns_admin" - finally: - del os.environ["DB_USERNAME"] - del os.environ["DB_PASSWORD"] - del os.environ["DB_DSN"] - del os.environ["TNS_ADMIN"] - - def test_main_sets_wallet_location_when_wallet_password_present(self): - """main() should set wallet_location when wallet_password is provided.""" - os.environ["DB_WALLET_PASSWORD"] = "wallet_pass" - os.environ["TNS_ADMIN"] = "/wallet/path" - - try: - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.wallet_password == "wallet_pass" - assert default_db.wallet_location == "/wallet/path" - finally: - del os.environ["DB_WALLET_PASSWORD"] - del os.environ["TNS_ADMIN"] - - def test_main_with_config_file_databases( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should load databases from config file.""" - settings = make_settings() - db1 = make_database(name="CONFIG_DB1", user="config_user1") - db2 = make_database(name="CONFIG_DB2", user="config_user2") - config_path = temp_config_file(client_settings=settings, database_configs=[db1, db2]) - - try: - reset_config_store.load_from_file(config_path) - integration_result = databases_module.main() - - db_names = [db.name for db in integration_result] - assert "CONFIG_DB1" in db_names - assert "CONFIG_DB2" in db_names - finally: - os.unlink(config_path) - - def 
test_main_overrides_default_from_config_with_env_vars( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should override DEFAULT database from config with env vars.""" - test_settings = make_settings() - test_db = make_database(name="DEFAULT", user="config_user", password="config_pass", dsn="config_dsn") - cfg_path = temp_config_file(client_settings=test_settings, database_configs=[test_db]) - - os.environ["DB_USERNAME"] = "env_user" - os.environ["DB_PASSWORD"] = "env_password" - - try: - reset_config_store.load_from_file(cfg_path) - db_list = databases_module.main() - default_entry = get_database_by_name(db_list, "DEFAULT") - assert default_entry.user == "env_user" - assert default_entry.password == "env_password" - assert default_entry.dsn == "config_dsn" # DSN not in env, keep config value - finally: - os.unlink(cfg_path) - del os.environ["DB_USERNAME"] - del os.environ["DB_PASSWORD"] - - def test_main_raises_on_duplicate_database_names( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should raise ValueError for duplicate database names.""" - settings = make_settings() - db1 = make_database(name="DUP_DB", user="user1") - db2 = make_database(name="dup_db", user="user2") # Case-insensitive duplicate - config_path = temp_config_file(client_settings=settings, database_configs=[db1, db2]) - - try: - reset_config_store.load_from_file(config_path) - - with pytest.raises(ValueError, match="Duplicate database name"): - databases_module.main() - finally: - os.unlink(config_path) - - def test_main_creates_default_when_not_in_config( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should create DEFAULT database from env when not in config.""" - test_settings = make_settings() - other_db = make_database(name="OTHER_DB", user="other_user") - cfg_path = temp_config_file(client_settings=test_settings, database_configs=[other_db]) - - os.environ["DB_USERNAME"] = "default_env_user" - - try: - reset_config_store.load_from_file(cfg_path) - db_list = databases_module.main() - assert_has_default_database(db_list) - assert "OTHER_DB" in [d.name for d in db_list] - default_entry = get_database_by_name(db_list, "DEFAULT") - assert default_entry.user == "default_env_user" - finally: - os.unlink(cfg_path) - del os.environ["DB_USERNAME"] - - def test_main_handles_case_insensitive_default_name( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should handle DEFAULT name case-insensitively.""" - settings = make_settings() - db = make_database(name="default", user="config_user") # lowercase - config_path = temp_config_file(client_settings=settings, database_configs=[db]) - - os.environ["DB_USERNAME"] = "env_user" - - try: - reset_config_store.load_from_file(config_path) - result = databases_module.main() - - # Should find and update the lowercase "default" - default_db = next(db for db in result if db.name.upper() == "DEFAULT") - assert default_db.user == "env_user" - finally: - os.unlink(config_path) - del os.environ["DB_USERNAME"] - - def test_main_preserves_non_default_databases_unchanged( - self, reset_config_store, temp_config_file, make_settings, make_database - ): - """main() should not modify non-DEFAULT databases.""" - test_settings = make_settings() - custom_db_config = make_database(name="CUSTOM_DB", user="custom_user", password="custom_pass") - cfg_path = temp_config_file(client_settings=test_settings, 
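-            # CUSTOM_DB is the only configured entry; the DB_USERNAME override
-            # set below must leave it untouched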
database_configs=[custom_db_config]) - - os.environ["DB_USERNAME"] = "should_not_apply" - - try: - reset_config_store.load_from_file(cfg_path) - db_list = databases_module.main() - custom_entry = get_database_by_name(db_list, "CUSTOM_DB") - assert custom_entry.user == "custom_user" - assert custom_entry.password == "custom_pass" - finally: - os.unlink(cfg_path) - del os.environ["DB_USERNAME"] - - def test_main_default_config_dir_fallback(self): - """main() should use 'tns_admin' as default config_dir when not specified.""" - result = databases_module.main() - default_db = get_database_by_name(result, "DEFAULT") - assert default_db.config_dir == "tns_admin" - - -@pytest.mark.usefixtures("reset_config_store", "clean_env") -class TestDatabasesMainAsScript: - """Tests for running databases module as script.""" - - def test_main_callable_directly(self): - """main() should be callable when running as script.""" - result = databases_module.main() - assert result is not None - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured in databases module.""" - assert hasattr(databases_module, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert databases_module.logger.name == "bootstrap.databases" diff --git a/test/unit/server/bootstrap/test_bootstrap_models.py b/test/unit/server/bootstrap/test_bootstrap_models.py deleted file mode 100644 index 8f8a8b1d..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_models.py +++ /dev/null @@ -1,413 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/models.py -Tests for model bootstrap functionality. 
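-
-Covers the base model list, API-key driven enablement, environment variable
-overrides, ConfigStore merging, and URL accessibility checks.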
-""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import os -from unittest.mock import patch - -from test.shared_fixtures import assert_model_list_valid, get_model_by_id - -import pytest - -from server.bootstrap import models as models_module - - -@pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible") -class TestModelsMain: - """Tests for the models.main() function.""" - - def test_main_returns_list_of_models(self): - """main() should return a list of Model objects.""" - result = models_module.main() - assert_model_list_valid(result) - - def test_main_includes_base_models(self): - """main() should include base model configurations.""" - result = models_module.main() - - model_ids = [m.id for m in result] - # Should include at least some base models - assert "gpt-4o-mini" in model_ids - assert "command-r" in model_ids - - def test_main_enables_models_with_api_keys(self): - """main() should enable models when API keys are present.""" - os.environ["OPENAI_API_KEY"] = "test-openai-key" - - try: - model_list = models_module.main() - gpt_model = get_model_by_id(model_list, "gpt-4o-mini") - assert gpt_model.enabled is True - assert gpt_model.api_key == "test-openai-key" - finally: - del os.environ["OPENAI_API_KEY"] - - def test_main_disables_models_without_api_keys(self): - """main() should disable models when API keys are not present.""" - model_list = models_module.main() - gpt_model = get_model_by_id(model_list, "gpt-4o-mini") - assert gpt_model.enabled is False - - @pytest.mark.usefixtures("reset_config_store", "clean_env") - def test_main_checks_url_accessibility(self): - """main() should check URL accessibility for enabled models.""" - os.environ["OPENAI_API_KEY"] = "test-key" - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (False, "Connection refused") - - try: - result = models_module.main() - openai_model = get_model_by_id(result, "gpt-4o-mini") - assert openai_model.enabled is False # Model disabled if URL not accessible - mock_accessible.assert_called() - finally: - del os.environ["OPENAI_API_KEY"] - - @pytest.mark.usefixtures("reset_config_store", "clean_env") - def test_main_caches_url_accessibility_results(self): - """main() should cache URL accessibility results for same URLs.""" - os.environ["OPENAI_API_KEY"] = "test-key" - os.environ["COHERE_API_KEY"] = "test-key" - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - - try: - models_module.main() - - # Multiple models share the same base URL, should only check once per URL - call_urls = [call[0][0] for call in mock_accessible.call_args_list] - # Should not have duplicate URL checks - assert len(call_urls) == len(set(call_urls)) - finally: - del os.environ["OPENAI_API_KEY"] - del os.environ["COHERE_API_KEY"] - - -@pytest.mark.usefixtures("clean_env") -class TestGetBaseModelsList: - """Tests for the _get_base_models_list function.""" - - def test_returns_list_of_dicts(self): - """_get_base_models_list should return a list of dictionaries.""" - result = models_module._get_base_models_list() - - assert isinstance(result, list) - assert all(isinstance(m, dict) for m in result) - - def test_includes_required_fields(self): - """_get_base_models_list should include required fields for each model.""" - result = models_module._get_base_models_list() - - for model in result: - assert "id" in model - assert "type" in model - assert 
"provider" in model - assert "api_base" in model - - def test_includes_ll_and_embed_models(self): - """_get_base_models_list should include both LLM and embedding models.""" - result = models_module._get_base_models_list() - - types = {m["type"] for m in result} - assert "ll" in types - assert "embed" in types - - -class TestCheckForDuplicates: - """Tests for the _check_for_duplicates function.""" - - def test_no_error_for_unique_models(self): - """_check_for_duplicates should not raise for unique models.""" - models_list = [ - {"id": "model1", "provider": "openai"}, - {"id": "model2", "provider": "openai"}, - {"id": "model1", "provider": "cohere"}, # Same ID, different provider - ] - - # Should not raise - models_module._check_for_duplicates(models_list) - - def test_raises_for_duplicate_models(self): - """_check_for_duplicates should raise ValueError for duplicates.""" - models_list = [ - {"id": "model1", "provider": "openai"}, - {"id": "model1", "provider": "openai"}, # Duplicate - ] - - with pytest.raises(ValueError, match="already exists"): - models_module._check_for_duplicates(models_list) - - -class TestValuesDiffer: - """Tests for the _values_differ function.""" - - def test_bool_comparison(self): - """_values_differ should handle boolean comparisons.""" - assert models_module._values_differ(True, False) is True - assert models_module._values_differ(True, True) is False - assert models_module._values_differ(False, False) is False - - def test_numeric_comparison(self): - """_values_differ should handle numeric comparisons.""" - assert models_module._values_differ(1, 2) is True - assert models_module._values_differ(1.0, 1.0) is False - assert models_module._values_differ(1, 1.0) is False - # Small float differences should be considered equal - assert models_module._values_differ(1.0, 1.0 + 1e-9) is False - assert models_module._values_differ(1.0, 1.1) is True - - def test_string_comparison(self): - """_values_differ should handle string comparisons with strip.""" - assert models_module._values_differ("test", "test") is False - assert models_module._values_differ(" test ", "test") is False - assert models_module._values_differ("test", "other") is True - - def test_general_comparison(self): - """_values_differ should handle general equality comparison.""" - assert models_module._values_differ([1, 2], [1, 2]) is False - assert models_module._values_differ([1, 2], [1, 3]) is True - assert models_module._values_differ(None, None) is False - assert models_module._values_differ(None, "value") is True - - -@pytest.mark.usefixtures("reset_config_store") -class TestMergeWithConfigStore: - """Tests for the _merge_with_config_store function.""" - - def test_returns_unchanged_when_no_config(self): - """_merge_with_config_store should return unchanged list when no config.""" - models_list = [{"id": "model1", "provider": "openai", "enabled": False}] - - result = models_module._merge_with_config_store(models_list) - - assert result == models_list - - def test_merges_config_store_models( - self, reset_config_store, temp_config_file, make_settings, make_model - ): - """_merge_with_config_store should merge models from ConfigStore.""" - settings = make_settings() - config_model = make_model(model_id="config-model", provider="custom") - config_path = temp_config_file(client_settings=settings, model_configs=[config_model]) - - models_list = [{"id": "existing", "provider": "openai", "enabled": False}] - - try: - reset_config_store.load_from_file(config_path) - result = 
models_module._merge_with_config_store(models_list) - - model_keys = [(m["provider"], m["id"]) for m in result] - assert ("custom", "config-model") in model_keys - assert ("openai", "existing") in model_keys - finally: - os.unlink(config_path) - - def test_overrides_existing_model_values( - self, reset_config_store, temp_config_file, make_settings, make_model - ): - """_merge_with_config_store should override existing model values.""" - settings = make_settings() - config_model = make_model(model_id="existing", provider="openai", enabled=True) - config_path = temp_config_file(client_settings=settings, model_configs=[config_model]) - - models_list = [ - {"id": "existing", "provider": "openai", "enabled": False, "api_base": "https://api.openai.com/v1"} - ] - - try: - reset_config_store.load_from_file(config_path) - result = models_module._merge_with_config_store(models_list) - - merged_model = next(m for m in result if m["id"] == "existing") - assert merged_model["enabled"] is True - finally: - os.unlink(config_path) - - -class ModelDict(dict): - """Dict subclass that also supports attribute access for 'id'. - - The _update_env_var function in models.py uses both dict-style (.get(), []) - and attribute-style (.id) access, so tests need objects that support both. - """ - - def __getattr__(self, name): - if name in self: - return self[name] - raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") - - -@pytest.mark.usefixtures("clean_env") -class TestApplyEnvVarOverrides: - """Tests for the _apply_env_var_overrides function.""" - - def test_applies_cohere_api_key(self): - """_apply_env_var_overrides should apply COHERE_API_KEY.""" - # Use ModelDict to support both dict and attribute access (needed for model.id) - models_list = [ModelDict({"id": "command-r", "provider": "cohere", "api_key": "original"})] - os.environ["COHERE_API_KEY"] = "env-key" - - try: - models_module._apply_env_var_overrides(models_list) - - assert models_list[0]["api_key"] == "env-key" - finally: - del os.environ["COHERE_API_KEY"] - - def test_applies_ollama_url(self): - """_apply_env_var_overrides should apply ON_PREM_OLLAMA_URL.""" - models_list = [ModelDict({"id": "llama3.1", "provider": "ollama", "api_base": "http://localhost:11434"})] - os.environ["ON_PREM_OLLAMA_URL"] = "http://custom:11434" - - try: - models_module._apply_env_var_overrides(models_list) - - assert models_list[0]["api_base"] == "http://custom:11434" - finally: - del os.environ["ON_PREM_OLLAMA_URL"] - - def test_does_not_apply_to_wrong_provider(self): - """_apply_env_var_overrides should not apply overrides to wrong provider.""" - models_list = [ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "original"})] - os.environ["COHERE_API_KEY"] = "env-key" - - try: - models_module._apply_env_var_overrides(models_list) - - assert models_list[0]["api_key"] == "original" - finally: - del os.environ["COHERE_API_KEY"] - - -@pytest.mark.usefixtures("clean_env") -class TestUpdateEnvVar: - """Tests for the _update_env_var function. - - Note: _update_env_var uses dict-style access (.get(), []) but also accesses - model.id directly for logging. Use ModelDict for compatibility. 
- """ - - def test_updates_matching_provider(self): - """_update_env_var should update model when provider matches.""" - model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "old"}) - os.environ["TEST_KEY"] = "new" - - try: - models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") - - assert model["api_key"] == "new" - finally: - del os.environ["TEST_KEY"] - - def test_ignores_non_matching_provider(self): - """_update_env_var should not update when provider doesn't match.""" - model = ModelDict({"id": "command-r", "provider": "cohere", "api_key": "old"}) - os.environ["TEST_KEY"] = "new" - - try: - models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") - - assert model["api_key"] == "old" - finally: - del os.environ["TEST_KEY"] - - def test_ignores_when_env_var_not_set(self): - """_update_env_var should not update when env var is not set.""" - model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "old"}) - - models_module._update_env_var(model, "openai", "api_key", "NONEXISTENT_VAR") - - assert model["api_key"] == "old" - - def test_ignores_when_value_unchanged(self): - """_update_env_var should not update when value is the same.""" - model = ModelDict({"id": "gpt-4o-mini", "provider": "openai", "api_key": "same"}) - os.environ["TEST_KEY"] = "same" - - try: - models_module._update_env_var(model, "openai", "api_key", "TEST_KEY") - - assert model["api_key"] == "same" - finally: - del os.environ["TEST_KEY"] - - -@pytest.mark.usefixtures("clean_env") -class TestCheckUrlAccessibility: - """Tests for the _check_url_accessibility function.""" - - def test_disables_inaccessible_urls(self): - """_check_url_accessibility should disable models with inaccessible URLs.""" - models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": True}] - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (False, "Connection refused") - - models_module._check_url_accessibility(models_list) - - assert models_list[0]["enabled"] is False - - def test_keeps_accessible_urls_enabled(self): - """_check_url_accessibility should keep models with accessible URLs enabled.""" - models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": True}] - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - - models_module._check_url_accessibility(models_list) - - assert models_list[0]["enabled"] is True - - def test_skips_disabled_models(self): - """_check_url_accessibility should skip models that are already disabled.""" - models_list = [{"id": "test", "api_base": "http://localhost:1234", "enabled": False}] - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - models_module._check_url_accessibility(models_list) - - mock_accessible.assert_not_called() - - def test_caches_url_results(self): - """_check_url_accessibility should cache results for the same URL.""" - models_list = [ - {"id": "test1", "api_base": "http://localhost:1234", "enabled": True}, - {"id": "test2", "api_base": "http://localhost:1234", "enabled": True}, - ] - - with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: - mock_accessible.return_value = (True, "OK") - - models_module._check_url_accessibility(models_list) - - # Should only be called once for the shared URL - assert mock_accessible.call_count == 1 - - -@pytest.mark.usefixtures("reset_config_store", "clean_env", 
"mock_is_url_accessible") -class TestModelsMainAsScript: - """Tests for running models module as script.""" - - def test_main_callable_directly(self): - """main() should be callable when running as script.""" - result = models_module.main() - assert result is not None - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured in models module.""" - assert hasattr(models_module, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert models_module.logger.name == "bootstrap.models" diff --git a/test/unit/server/bootstrap/test_bootstrap_oci.py b/test/unit/server/bootstrap/test_bootstrap_oci.py deleted file mode 100644 index 89242c8e..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_oci.py +++ /dev/null @@ -1,329 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/oci.py -Tests for OCI bootstrap functionality. -""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import os -from unittest.mock import patch, MagicMock - -import pytest -import oci - -from server.bootstrap import oci as oci_module -from common.schema import OracleCloudSettings - - -@pytest.mark.usefixtures("reset_config_store", "clean_env") -class TestOciMain: - """Tests for the oci.main() function.""" - - def test_main_returns_list_of_oci_settings(self): - """main() should return a list of OracleCloudSettings objects.""" - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - assert isinstance(result, list) - assert all(isinstance(s, OracleCloudSettings) for s in result) - - def test_main_creates_default_profile_when_no_config(self): - """main() should create DEFAULT profile when no OCI config exists.""" - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - profile_names = [s.auth_profile for s in result] - assert oci.config.DEFAULT_PROFILE in profile_names - - def test_main_reads_oci_config_file(self): - """main() should read from OCI config file when it exists.""" - # User OCID must match pattern ^([0-9a-zA-Z-_]+[.:])([0-9a-zA-Z-_]*[.:]){3,}([0-9a-zA-Z-_]+)$ - mock_config_data = { - "tenancy": "ocid1.tenancy.oc1..test123", - "region": "us-phoenix-1", - "user": "ocid1.user.oc1..test123", # Valid OCID pattern - "fingerprint": "test-fingerprint", - "key_file": "/path/to/key.pem", - } - - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = [] - mock_parser.return_value = mock_instance - - with patch("oci.config.from_file", return_value=mock_config_data.copy()): - result = oci_module.main() - - assert len(result) >= 1 - default_profile = next((p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE), None) - assert default_profile is not None - - def test_main_applies_env_var_overrides_to_default(self): - """main() should apply environment variable overrides to DEFAULT profile.""" - # User OCID must match pattern ^([0-9a-zA-Z-_]+[.:])([0-9a-zA-Z-_]*[.:]){3,}([0-9a-zA-Z-_]+)$ - os.environ["OCI_CLI_TENANCY"] = "env-tenancy" - os.environ["OCI_CLI_REGION"] = "us-chicago-1" - os.environ["OCI_CLI_USER"] = "ocid1.user.oc1..envuser123" # Valid OCID pattern - - try: - with patch("oci.config.from_file", 
side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.tenancy == "env-tenancy" - assert default_profile.region == "us-chicago-1" - assert default_profile.user == "ocid1.user.oc1..envuser123" - finally: - del os.environ["OCI_CLI_TENANCY"] - del os.environ["OCI_CLI_REGION"] - del os.environ["OCI_CLI_USER"] - - def test_main_env_overrides_genai_settings(self): - """main() should apply GenAI environment variable overrides.""" - # genai_compartment_id must match OCID pattern - os.environ["OCI_GENAI_COMPARTMENT_ID"] = "ocid1.compartment.oc1..genaitest" - os.environ["OCI_GENAI_REGION"] = "us-chicago-1" - - try: - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.genai_compartment_id == "ocid1.compartment.oc1..genaitest" - assert default_profile.genai_region == "us-chicago-1" - finally: - del os.environ["OCI_GENAI_COMPARTMENT_ID"] - del os.environ["OCI_GENAI_REGION"] - - def test_main_security_token_authentication(self): - """main() should set authentication based on security_token_file in profile. - - Note: Due to how profile.update() works, the authentication logic reads the - OLD value of security_token_file before the update completes. If security_token_file - is already set in the profile, authentication becomes 'security_token'. - For env var alone without existing profile value, use OCI_CLI_AUTH instead. - """ - # To get security_token auth, we need OCI_CLI_AUTH explicitly set - # OR we need security_token_file already in the profile before overrides - os.environ["OCI_CLI_SECURITY_TOKEN_FILE"] = "/path/to/token" - os.environ["OCI_CLI_AUTH"] = "security_token" # Must explicitly set - - try: - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.authentication == "security_token" - assert default_profile.security_token_file == "/path/to/token" - finally: - del os.environ["OCI_CLI_SECURITY_TOKEN_FILE"] - del os.environ["OCI_CLI_AUTH"] - - def test_main_explicit_auth_env_var(self): - """main() should use OCI_CLI_AUTH env var when specified.""" - os.environ["OCI_CLI_AUTH"] = "instance_principal" - - try: - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - assert default_profile.authentication == "instance_principal" - finally: - del os.environ["OCI_CLI_AUTH"] - - def test_main_loads_multiple_profiles(self): - """main() should load multiple profiles from OCI config.""" - profiles = ["PROFILE1", "PROFILE2"] - - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = profiles - mock_parser.return_value = mock_instance - - def mock_from_file(**kwargs): - profile_name = kwargs.get("profile_name") - # User must be None or valid OCID pattern - return { - "tenancy": f"tenancy-{profile_name}", - "region": "us-ashburn-1", - "fingerprint": "fingerprint", - "key_file": "/path/to/key.pem", - } - - with patch("oci.config.from_file", side_effect=mock_from_file): - result = 
oci_module.main() - - profile_names = [p.auth_profile for p in result] - assert "PROFILE1" in profile_names - assert "PROFILE2" in profile_names - - def test_main_handles_invalid_key_file_path(self): - """main() should skip profiles with invalid key file paths.""" - profiles = ["VALID", "INVALID"] - - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = profiles - mock_parser.return_value = mock_instance - - def mock_from_file(**kwargs): - profile_name = kwargs.get("profile_name") - if profile_name == "INVALID": - raise oci.exceptions.InvalidKeyFilePath("Invalid key file") - # User must be None or valid OCID pattern - return { - "tenancy": "tenancy", - "region": "us-ashburn-1", - "fingerprint": "fingerprint", - "key_file": "/path/to/key.pem", - } - - with patch("oci.config.from_file", side_effect=mock_from_file): - result = oci_module.main() - - profile_names = [p.auth_profile for p in result] - assert "VALID" in profile_names - # INVALID should be skipped, DEFAULT should be created - - def test_main_merges_config_store_oci_configs( - self, reset_config_store, temp_config_file, make_settings, make_oci_config - ): - """main() should merge OCI configs from ConfigStore.""" - settings = make_settings() - oci_config = make_oci_config(auth_profile="CONFIG_PROFILE", tenancy="config-tenancy") - config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) - - try: - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - reset_config_store.load_from_file(config_path) - result = oci_module.main() - - profile_names = [p.auth_profile for p in result] - assert "CONFIG_PROFILE" in profile_names - - config_profile = next(p for p in result if p.auth_profile == "CONFIG_PROFILE") - assert config_profile.tenancy == "config-tenancy" - finally: - os.unlink(config_path) - - def test_main_config_store_overrides_existing_profile( - self, reset_config_store, temp_config_file, make_settings, make_oci_config - ): - """main() should override existing profiles with ConfigStore configs.""" - settings = make_settings() - oci_config = make_oci_config(auth_profile=oci.config.DEFAULT_PROFILE, tenancy="override-tenancy") - config_path = temp_config_file(client_settings=settings, oci_configs=[oci_config]) - - # User must be None or valid OCID pattern - mock_file_config = { - "tenancy": "file-tenancy", - "region": "us-ashburn-1", - "fingerprint": "fingerprint", - "key_file": "/path/to/key.pem", - } - - try: - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = [] - mock_parser.return_value = mock_instance - - with patch("oci.config.from_file", return_value=mock_file_config.copy()): - reset_config_store.load_from_file(config_path) - result = oci_module.main() - - default_profile = next(p for p in result if p.auth_profile == oci.config.DEFAULT_PROFILE) - # ConfigStore should override file config - assert default_profile.tenancy == "override-tenancy" - finally: - os.unlink(config_path) - - def test_main_uses_custom_config_file_path(self): - """main() should use OCI_CLI_CONFIG_FILE env var for config path.""" - custom_path = "/custom/oci/config" - os.environ["OCI_CLI_CONFIG_FILE"] = custom_path - - try: - with patch("configparser.ConfigParser") as mock_parser: - mock_instance = MagicMock() - mock_instance.sections.return_value = [] - mock_parser.return_value = mock_instance - - with patch("oci.config.from_file", 
side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - - # The expanded path should be used - assert len(result) >= 1 - finally: - del os.environ["OCI_CLI_CONFIG_FILE"] - - -@pytest.mark.usefixtures("clean_env") -class TestApplyEnvOverrides: - """Tests for the _apply_env_overrides_to_default_profile function.""" - - def test_override_function_modifies_default_profile(self): - """_apply_env_overrides_to_default_profile should modify DEFAULT profile.""" - config = [{"auth_profile": oci.config.DEFAULT_PROFILE, "tenancy": "original"}] - - os.environ["OCI_CLI_TENANCY"] = "overridden" - - try: - oci_module._apply_env_overrides_to_default_profile(config) - - assert config[0]["tenancy"] == "overridden" - finally: - del os.environ["OCI_CLI_TENANCY"] - - def test_override_function_ignores_non_default_profiles(self): - """_apply_env_overrides_to_default_profile should not modify non-DEFAULT profiles.""" - config = [{"auth_profile": "CUSTOM", "tenancy": "original"}] - - os.environ["OCI_CLI_TENANCY"] = "overridden" - - try: - oci_module._apply_env_overrides_to_default_profile(config) - - assert config[0]["tenancy"] == "original" - finally: - del os.environ["OCI_CLI_TENANCY"] - - def test_override_logs_changes(self, caplog): - """_apply_env_overrides_to_default_profile should log overrides.""" - config = [{"auth_profile": oci.config.DEFAULT_PROFILE, "tenancy": "original"}] - - os.environ["OCI_CLI_TENANCY"] = "new-tenancy" - - try: - oci_module._apply_env_overrides_to_default_profile(config) - - assert "Environment variable overrides" in caplog.text or "new-tenancy" in str(config) - finally: - del os.environ["OCI_CLI_TENANCY"] - - -@pytest.mark.usefixtures("reset_config_store", "clean_env") -class TestOciMainAsScript: - """Tests for running OCI module as script.""" - - def test_main_callable_directly(self): - """main() should be callable when running as script.""" - with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()): - result = oci_module.main() - assert result is not None - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured in oci module.""" - assert hasattr(oci_module, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert oci_module.logger.name == "bootstrap.oci" diff --git a/test/unit/server/bootstrap/test_bootstrap_settings.py b/test/unit/server/bootstrap/test_bootstrap_settings.py deleted file mode 100644 index 5bac59b8..00000000 --- a/test/unit/server/bootstrap/test_bootstrap_settings.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for server/bootstrap/settings.py -Tests for settings bootstrap functionality. 
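The DEFAULT-profile-only override rule asserted by the deleted OCI bootstrap tests above is worth seeing in isolation. The sketch below is an illustrative reconstruction, not the module's real code: apply_env_overrides and ENV_MAP are hypothetical names, and the actual helper handles more OCI_CLI_* variables.

# Sketch: only the DEFAULT profile honours OCI_CLI_* environment variables.
import os

import oci

ENV_MAP = {  # assumed subset of the variables the real module maps
    "OCI_CLI_TENANCY": "tenancy",
    "OCI_CLI_REGION": "region",
    "OCI_CLI_USER": "user",
}

def apply_env_overrides(profiles: list) -> None:
    """Mutate the DEFAULT profile in place; other profiles are left untouched."""
    for profile in profiles:
        if profile.get("auth_profile") != oci.config.DEFAULT_PROFILE:
            continue
        for env_var, key in ENV_MAP.items():
            value = os.environ.get(env_var)
            if value:
                profile[key] = value

# Mirrors the deleted assertions: DEFAULT picks the value up, CUSTOM does not.
config = [
    {"auth_profile": "DEFAULT", "tenancy": "original"},
    {"auth_profile": "CUSTOM", "tenancy": "original"},
]
os.environ["OCI_CLI_TENANCY"] = "overridden"
apply_env_overrides(config)
assert config[0]["tenancy"] == "overridden"
assert config[1]["tenancy"] == "original"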
-""" - -# pylint: disable=redefined-outer-name protected-access too-few-public-methods - -import os -from unittest.mock import patch, MagicMock - -import pytest - -from server.bootstrap import settings as settings_module -from common.schema import Settings - - -@pytest.mark.usefixtures("reset_config_store") -class TestSettingsMain: - """Tests for the settings.main() function.""" - - def test_main_returns_list_of_settings(self): - """main() should return a list of Settings objects.""" - result = settings_module.main() - - assert isinstance(result, list) - assert all(isinstance(s, Settings) for s in result) - - def test_main_creates_default_and_server_clients(self): - """main() should create settings for 'default' and 'server' clients.""" - result = settings_module.main() - - client_names = [s.client for s in result] - assert "default" in client_names - assert "server" in client_names - assert len(result) == 2 - - def test_main_without_config_uses_default_settings(self): - """main() should use default Settings when no config is loaded.""" - result = settings_module.main() - - # Both should have default Settings values - for s in result: - assert isinstance(s, Settings) - assert s.client in ["default", "server"] - - def test_main_with_config_uses_config_settings(self, reset_config_store, temp_config_file, make_settings): - """main() should use config file settings when available.""" - # Create settings with custom values - custom_settings = make_settings(client="config_client") - custom_settings.ll_model.temperature = 0.9 - custom_settings.ll_model.max_tokens = 8192 - - config_path = temp_config_file(client_settings=custom_settings) - - try: - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - # Both clients should inherit from config settings - for s in result: - assert s.ll_model.temperature == 0.9 - assert s.ll_model.max_tokens == 8192 - # Client name should be overridden to default/server - assert s.client in ["default", "server"] - finally: - os.unlink(config_path) - - def test_main_preserves_client_names_from_base_list(self, reset_config_store, temp_config_file, make_settings): - """main() should override client field from config with base client names.""" - custom_settings = make_settings(client="original_name") - config_path = temp_config_file(client_settings=custom_settings) - - try: - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - # Client names should be "default" and "server", not "original_name" - client_names = [s.client for s in result] - assert "original_name" not in client_names - assert "default" in client_names - assert "server" in client_names - finally: - os.unlink(config_path) - - def test_main_with_config_but_no_client_settings(self, reset_config_store): - """main() should use default Settings when config has no client_settings.""" - mock_config = MagicMock() - mock_config.client_settings = None - - with patch.object(reset_config_store, "get", return_value=mock_config): - result = settings_module.main() - - assert len(result) == 2 - assert all(isinstance(s, Settings) for s in result) - - def test_main_creates_copies_with_different_clients(self, reset_config_store, temp_config_file, make_settings): - """main() should create separate Settings objects with unique client names. - - Note: Pydantic's model_copy() creates shallow copies by default, - so nested objects (like ll_model) may be shared. However, the top-level - Settings objects should be distinct with their own 'client' values. 
- """ - custom_settings = make_settings(client="config_client") - config_path = temp_config_file(client_settings=custom_settings) - - try: - reset_config_store.load_from_file(config_path) - result = settings_module.main() - - # The Settings objects themselves should be distinct - assert result[0] is not result[1] - # And have different client names - assert result[0].client != result[1].client - assert result[0].client in ["default", "server"] - assert result[1].client in ["default", "server"] - finally: - os.unlink(config_path) - - -@pytest.mark.usefixtures("reset_config_store") -class TestSettingsMainAsScript: - """Tests for running settings module as script.""" - - def test_main_callable_directly(self): - """main() should be callable when running as script.""" - # This tests the if __name__ == "__main__" block indirectly - result = settings_module.main() - assert result is not None - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured in settings module.""" - assert hasattr(settings_module, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert settings_module.logger.name == "bootstrap.settings" diff --git a/tests/client/integration/content/tools/tabs/test_split_embed.py b/tests/client/integration/content/tools/tabs/test_split_embed.py index d20998ba..7fdd7b47 100644 --- a/tests/client/integration/content/tools/tabs/test_split_embed.py +++ b/tests/client/integration/content/tools/tabs/test_split_embed.py @@ -621,18 +621,18 @@ def test_create_new_vs_toggle_shown_when_vector_stores_exist(self, app_server, a # Ensure database has vector stores if at.session_state.database_configs: - # Find matching model ID for the vector store (format: provider/id) - model_key = None + # Find matching model ID for the vector store + model_id = None for model in at.session_state.model_configs: if model["type"] == "embed" and model.get("enabled"): - model_key = f"{model.get('provider')}/{model['id']}" + model_id = model["id"] break - if model_key: + if model_id: at.session_state.database_configs[0]["vector_stores"] = [ { "alias": "existing_vs", - "model": model_key, + "model": model_id, "vector_store": "VECTOR_STORE_TABLE", "chunk_size": 500, "chunk_overlap": 50, diff --git a/tests/client/integration/utils/test_st_common.py b/tests/client/integration/utils/test_st_common.py index 0103607a..164ecb13 100644 --- a/tests/client/integration/utils/test_st_common.py +++ b/tests/client/integration/utils/test_st_common.py @@ -2,8 +2,329 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for st_common utilities. 
-Vector store selection tests have been moved to test_vs_options.py """ # spell-checker: disable + +from unittest.mock import patch + +import pandas as pd +import pytest +import streamlit as st +from streamlit import session_state as state + +from client.utils import st_common + + +############################################################################# +# Fixtures +############################################################################# +@pytest.fixture +def vector_store_state(sample_vector_store_data): + """Setup common vector store state for tests using shared test data""" + # Setup initial state with vector search settings + state.client_settings = { + "vector_search": { + "enabled": True, + **sample_vector_store_data, + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + "ll_model": {"model": "gpt-4", "temperature": 0.8}, + } + + # Set widget states to simulate user selections + state.selected_vector_search_model = sample_vector_store_data["model"] + state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] + state.selected_vector_search_chunk_overlap = sample_vector_store_data["chunk_overlap"] + state.selected_vector_search_distance_metric = sample_vector_store_data["distance_metric"] + state.selected_vector_search_alias = sample_vector_store_data["alias"] + state.selected_vector_search_index_type = sample_vector_store_data["index_type"] + + yield state + + # Cleanup after test + for key in list(state.keys()): + if key.startswith("selected_vector_search_"): + del state[key] + + +############################################################################# +# Test Vector Store Reset Button Functionality - Integration Tests +############################################################################# +class TestVectorStoreResetButtonIntegration: + """Integration tests for vector store selection Reset button""" + + def test_reset_button_callback_execution(self, app_server, vector_store_state, sample_vector_store_data): + """Test that the Reset button callback is properly executed when clicked""" + assert app_server is not None + assert vector_store_state is not None + + reset_callback_executed = False + + def mock_button(label, **kwargs): + nonlocal reset_callback_executed + if "Reset" in label and "on_click" in kwargs: + # Execute the callback to simulate button click + kwargs["on_click"]() + reset_callback_executed = True + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox"), + patch.object(st, "info"), + ): + # Create test dataframe using shared test data + vs_df = pd.DataFrame([sample_vector_store_data]) + + # Mock enabled models + with patch.object(st_common, "enabled_models_lookup") as mock_models: + mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} + + # Call the function + st_common.render_vector_store_selection(vs_df) + + # Verify reset callback was executed + assert reset_callback_executed + + # Verify all widget states are cleared + assert state.selected_vector_search_model == "" + assert state.selected_vector_search_chunk_size == "" + assert state.selected_vector_search_chunk_overlap == "" + assert state.selected_vector_search_distance_metric == "" + assert state.selected_vector_search_alias == "" + assert state.selected_vector_search_index_type == "" + + # Verify client_settings are also cleared + assert 
state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["chunk_size"] == "" + assert state.client_settings["vector_search"]["chunk_overlap"] == "" + assert state.client_settings["vector_search"]["distance_metric"] == "" + assert state.client_settings["vector_search"]["vector_store"] == "" + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["index_type"] == "" + + def test_reset_preserves_non_vector_store_settings(self, app_server, vector_store_state, sample_vector_store_data): + """Test that Reset only affects vector store fields, not other settings""" + assert app_server is not None + assert vector_store_state is not None + + def mock_button(label, **kwargs): + if "Reset" in label and "on_click" in kwargs: + kwargs["on_click"]() + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox"), + patch.object(st, "info"), + ): + vs_df = pd.DataFrame([sample_vector_store_data]) + + with patch.object(st_common, "enabled_models_lookup") as mock_models: + mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} + st_common.render_vector_store_selection(vs_df) + + # Vector store fields should be cleared + assert state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["alias"] == "" + + # Other settings should be preserved + assert state.client_settings["vector_search"]["top_k"] == 10 + assert state.client_settings["vector_search"]["search_type"] == "Similarity" + assert state.client_settings["vector_search"]["score_threshold"] == 0.5 + assert state.client_settings["database"]["alias"] == "DEFAULT" + assert state.client_settings["ll_model"]["model"] == "gpt-4" + assert state.client_settings["ll_model"]["temperature"] == 0.8 + + def test_auto_population_after_reset_single_option(self, app_server, sample_vector_store_data): + """Test that fields with single options are auto-populated after reset""" + assert app_server is not None + + # Setup clean state + state.client_settings = { + "vector_search": { + "enabled": True, + "model": "", # Empty after reset + "chunk_size": "", + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "", + "index_type": "", + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + + # Clear widget states (simulating post-reset state) + state.selected_vector_search_model = "" + state.selected_vector_search_chunk_size = "" + state.selected_vector_search_chunk_overlap = "" + state.selected_vector_search_distance_metric = "" + state.selected_vector_search_alias = "" + state.selected_vector_search_index_type = "" + + selectbox_calls = [] + + def mock_selectbox(label, options, key, index, disabled=False): + selectbox_calls.append( + {"label": label, "options": options, "key": key, "index": index, "disabled": disabled} + ) + # Return the value at index + return options[index] if 0 <= index < len(options) else "" + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button"), + patch.object(st.sidebar, "selectbox", side_effect=mock_selectbox), + patch.object(st, "info"), + ): + # Create dataframe with single option per field using shared fixture + single_vs = sample_vector_store_data.copy() + single_vs["alias"] = "single_alias" + single_vs["vector_store"] = 
"single_vs" + vs_df = pd.DataFrame([single_vs]) + + with patch.object(st_common, "enabled_models_lookup") as mock_models: + mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} + st_common.render_vector_store_selection(vs_df) + + # Verify auto-population happened for single options + assert state.client_settings["vector_search"]["alias"] == "single_alias" + assert state.client_settings["vector_search"]["model"] == sample_vector_store_data["model"] + assert state.client_settings["vector_search"]["chunk_size"] == sample_vector_store_data["chunk_size"] + assert state.client_settings["vector_search"]["chunk_overlap"] == sample_vector_store_data["chunk_overlap"] + assert ( + state.client_settings["vector_search"]["distance_metric"] + == sample_vector_store_data["distance_metric"] + ) + assert state.client_settings["vector_search"]["index_type"] == sample_vector_store_data["index_type"] + + # Verify widget states were also set + assert state.selected_vector_search_alias == "single_alias" + assert state.selected_vector_search_model == sample_vector_store_data["model"] + + def test_no_auto_population_with_multiple_options( + self, app_server, sample_vector_store_data, sample_vector_store_data_alt + ): + """Test that fields with multiple options are NOT auto-populated after reset""" + assert app_server is not None + + # Setup clean state after reset + state.client_settings = { + "vector_search": { + "enabled": True, + "model": "", + "chunk_size": "", + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "", + "index_type": "", + "top_k": 10, + "search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + + # Clear widget states + for key in ["model", "chunk_size", "chunk_overlap", "distance_metric", "alias", "index_type"]: + state[f"selected_vector_search_{key}"] = "" + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button"), + patch.object(st.sidebar, "selectbox", return_value=""), + patch.object(st, "info"), + ): + # Create dataframe with multiple options using shared fixtures + vs1 = sample_vector_store_data.copy() + vs1["alias"] = "alias1" + vs2 = sample_vector_store_data_alt.copy() + vs2["alias"] = "alias2" + vs_df = pd.DataFrame([vs1, vs2]) + + with patch.object(st_common, "enabled_models_lookup") as mock_models: + mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} + st_common.render_vector_store_selection(vs_df) + + # With multiple options, fields should remain empty (no auto-population) + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["chunk_size"] == "" + assert state.client_settings["vector_search"]["chunk_overlap"] == "" + assert state.client_settings["vector_search"]["distance_metric"] == "" + assert state.client_settings["vector_search"]["index_type"] == "" + + def test_reset_button_with_filtered_dataframe( + self, app_server, sample_vector_store_data, sample_vector_store_data_alt + ): + """Test reset button behavior with dynamically filtered dataframes""" + assert app_server is not None + + # Setup state with a filter already applied + state.client_settings = { + "vector_search": { + "enabled": True, + "model": sample_vector_store_data["model"], + "chunk_size": sample_vector_store_data["chunk_size"], + "chunk_overlap": "", + "distance_metric": "", + "vector_store": "", + "alias": "alias1", # Filter applied + "index_type": "", + "top_k": 10, + 
"search_type": "Similarity", + "score_threshold": 0.5, + "fetch_k": 20, + "lambda_mult": 0.5, + }, + "database": {"alias": "DEFAULT"}, + } + + state.selected_vector_search_alias = "alias1" + state.selected_vector_search_model = sample_vector_store_data["model"] + state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] + + def mock_button(label, **kwargs): + if "Reset" in label and "on_click" in kwargs: + kwargs["on_click"]() + return True + + with ( + patch.object(st.sidebar, "subheader"), + patch.object(st.sidebar, "button", side_effect=mock_button), + patch.object(st.sidebar, "selectbox", return_value=""), + patch.object(st, "info"), + ): + # Create dataframe with same alias using shared fixtures + vs1 = sample_vector_store_data.copy() + vs1["alias"] = "alias1" + vs2 = sample_vector_store_data_alt.copy() + vs2["alias"] = "alias1" + vs_df = pd.DataFrame([vs1, vs2]) + + with patch.object(st_common, "enabled_models_lookup") as mock_models: + mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} + st_common.render_vector_store_selection(vs_df) + + # After reset, all filters should be cleared + assert state.selected_vector_search_alias == "" + assert state.selected_vector_search_model == "" + assert state.selected_vector_search_chunk_size == "" + assert state.client_settings["vector_search"]["alias"] == "" + assert state.client_settings["vector_search"]["model"] == "" + assert state.client_settings["vector_search"]["chunk_size"] == "" diff --git a/tests/client/unit/content/test_chatbot_unit.py b/tests/client/unit/content/test_chatbot_unit.py index 9d3afd0c..66b309ed 100644 --- a/tests/client/unit/content/test_chatbot_unit.py +++ b/tests/client/unit/content/test_chatbot_unit.py @@ -138,7 +138,7 @@ def test_setup_sidebar_no_models(self, monkeypatch): def test_setup_sidebar_with_models(self, monkeypatch): """Test setup_sidebar with enabled language models""" from client.content import chatbot - from client.utils import st_common, vs_options + from client.utils import st_common from streamlit import session_state as state # Mock enabled_models_lookup to return models @@ -148,7 +148,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): monkeypatch.setattr(st_common, "tools_sidebar", MagicMock()) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) + monkeypatch.setattr(st_common, "vector_search_sidebar", MagicMock()) # Initialize state state.enable_client = True @@ -162,7 +162,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): def test_setup_sidebar_client_disabled(self, monkeypatch): """Test setup_sidebar when client gets disabled""" from client.content import chatbot - from client.utils import st_common, vs_options + from client.utils import st_common from streamlit import session_state as state import streamlit as st @@ -175,7 +175,7 @@ def disable_client(): monkeypatch.setattr(st_common, "tools_sidebar", disable_client) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) + monkeypatch.setattr(st_common, "vector_search_sidebar", MagicMock()) # Mock st.stop mock_stop = MagicMock(side_effect=SystemExit) diff --git a/tests/client/unit/content/tools/tabs/test_split_embed_unit.py b/tests/client/unit/content/tools/tabs/test_split_embed_unit.py index 
fcfd6b9d..39bdce27 100644 --- a/tests/client/unit/content/tools/tabs/test_split_embed_unit.py +++ b/tests/client/unit/content/tools/tabs/test_split_embed_unit.py @@ -270,24 +270,27 @@ def test_update_chunk_size_input(self): class TestSplitEmbedEdgeCases: """Tests for edge cases and validation in split_embed implementation""" - def test_chunk_overlap_syncs_slider_to_input(self): + def test_chunk_overlap_validation(self): """ - Test that update_chunk_overlap_input syncs slider value to input. + Test that chunk_overlap should not exceed chunk_size. - The function copies the slider value to the input field. - Note: Validation of overlap < size is handled at the UI level, not in this function. + This validates proper chunk configuration to prevent text splitting issues. + If this test fails, it indicates chunk_overlap is allowed to exceed chunk_size. """ from client.content.tools.tabs.split_embed import update_chunk_overlap_input from streamlit import session_state as state - # Setup state - state.selected_chunk_overlap_slider = 500 + # Setup state with overlap > size (function copies FROM slider TO input) + state.selected_chunk_overlap_slider = 2000 # Overlap (will be copied to input) + state.selected_chunk_size_slider = 1000 # Size (smaller!) # Call function update_chunk_overlap_input() - # Verify the value was copied from slider to input - assert state.selected_chunk_overlap_input == 500 + # EXPECTED: overlap should be capped at chunk_size or validation should prevent this + # If this assertion fails, it exposes lack of validation + assert state.selected_chunk_overlap_input < state.selected_chunk_size_slider, \ + "Chunk overlap should not exceed chunk size" def test_files_data_frame_process_column_added(self): """ diff --git a/tests/client/unit/utils/test_st_common_unit.py b/tests/client/unit/utils/test_st_common_unit.py index 1eee4014..1884dc24 100644 --- a/tests/client/unit/utils/test_st_common_unit.py +++ b/tests/client/unit/utils/test_st_common_unit.py @@ -6,8 +6,10 @@ # spell-checker: disable from io import BytesIO -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import pandas as pd +import streamlit as st from streamlit import session_state as state from client.utils import api_call, st_common @@ -392,3 +394,277 @@ def test_is_db_configured_false_different_alias(self, app_server): result = st_common.is_db_configured() assert result is False + + +############################################################################# +# Test Vector Store Helpers +############################################################################# +class TestVectorStoreHelpers: + """Test vector store helper functions""" + + def test_update_filtered_vector_store_no_filters(self, app_server, sample_vector_stores_list): + """Test update_filtered_vector_store with no filters""" + assert app_server is not None + + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, + ] + + vs_df = pd.DataFrame(sample_vector_stores_list) + + result = st_common.update_filtered_vector_store(vs_df) + + # Should return all rows (filtered by enabled models only) + assert len(result) == 2 + + def test_update_filtered_vector_store_with_alias_filter(self, app_server, sample_vector_stores_list): + """Test update_filtered_vector_store with alias filter""" + assert app_server is not None + + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, + ] + state.selected_vector_search_alias = "vs1" + + vs_df 
= pd.DataFrame(sample_vector_stores_list) + + result = st_common.update_filtered_vector_store(vs_df) + + # Should only return vs1 + assert len(result) == 1 + assert result.iloc[0]["alias"] == "vs1" + + def test_update_filtered_vector_store_disabled_model(self, app_server, sample_vector_store_data): + """Test that disabled embedding models filter out vector stores""" + assert app_server is not None + + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": False}, + ] + + # Use shared fixture with vs1 alias + vs1 = sample_vector_store_data.copy() + vs1["alias"] = "vs1" + vs1.pop("vector_store", None) + vs_df = pd.DataFrame([vs1]) + + result = st_common.update_filtered_vector_store(vs_df) + + # Should return empty (model not enabled) + assert len(result) == 0 + + def test_update_filtered_vector_store_multiple_filters(self, app_server, sample_vector_stores_list): + """Test update_filtered_vector_store with multiple filters""" + assert app_server is not None + + state.model_configs = [ + {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, + ] + state.selected_vector_search_alias = "vs1" + state.selected_vector_search_model = "openai/text-embed-3" + state.selected_vector_search_chunk_size = 1000 + + # Use only vs1 entries from the fixture + vs1_entries = [vs.copy() for vs in sample_vector_stores_list] + for vs in vs1_entries: + vs["alias"] = "vs1" + + vs_df = pd.DataFrame(vs1_entries) + + result = st_common.update_filtered_vector_store(vs_df) + + # Should only return the 1000 chunk_size entry + assert len(result) == 1 + assert result.iloc[0]["chunk_size"] == 1000 + + +############################################################################# +# Test _vs_gen_selectbox Function +############################################################################# +class TestVsGenSelectbox: + """Unit tests for the _vs_gen_selectbox function""" + + def test_single_option_auto_select_when_empty(self, app_server): + """Test auto-selection when there's one option and current value is empty""" + assert app_server is not None + + # Setup: empty current value + state.client_settings = {"vector_search": {"alias": ""}} + + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "single_option" + + st_common._vs_gen_selectbox("Select Alias:", ["single_option"], "selected_vector_search_alias") + + # Verify auto-selection occurred + assert state.client_settings["vector_search"]["alias"] == "single_option" + assert state.selected_vector_search_alias == "single_option" + + # Verify selectbox was called with correct index (1 = first real option after empty) + mock_selectbox.assert_called_once() + call_args = mock_selectbox.call_args + assert call_args[1]["index"] == 1 # Index 1 points to "single_option" in ["", "single_option"] + + def test_single_option_no_auto_select_when_populated(self, app_server): + """Test NO auto-selection when there's one option but value already exists""" + assert app_server is not None + + # Setup: existing value + state.client_settings = {"vector_search": {"alias": "existing_value"}} + + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "existing_value" + + st_common._vs_gen_selectbox("Select Alias:", ["existing_value"], "selected_vector_search_alias") + + # Value should remain unchanged (not overwritten) + assert state.client_settings["vector_search"]["alias"] == "existing_value" + + # Verify selectbox was called with existing value's 
index + mock_selectbox.assert_called_once() + call_args = mock_selectbox.call_args + assert call_args[1]["index"] == 1 # existing_value is at index 1 + + def test_multiple_options_no_auto_select(self, app_server): + """Test no auto-selection with multiple options""" + assert app_server is not None + + # Setup: empty value with multiple options + state.client_settings = {"vector_search": {"alias": ""}} + + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "" + + st_common._vs_gen_selectbox( + "Select Alias:", ["option1", "option2", "option3"], "selected_vector_search_alias" + ) + + # Should remain empty (no auto-selection) + assert state.client_settings["vector_search"]["alias"] == "" + + # Verify selectbox was called with index 0 (empty option) + mock_selectbox.assert_called_once() + call_args = mock_selectbox.call_args + assert call_args[1]["index"] == 0 # Index 0 is the empty option + + def test_no_valid_options_disabled(self, app_server): + """Test selectbox is disabled when no valid options""" + assert app_server is not None + + state.client_settings = {"vector_search": {"alias": ""}} + + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "" + + st_common._vs_gen_selectbox( + "Select Alias:", + [], # No options + "selected_vector_search_alias", + ) + + # Verify selectbox was called with disabled=True + mock_selectbox.assert_called_once() + call_args = mock_selectbox.call_args + assert call_args[1]["disabled"] is True + assert call_args[1]["index"] == 0 + + def test_invalid_current_value_reset(self, app_server): + """Test that invalid current value is reset to empty""" + assert app_server is not None + + # Setup: value that's not in the options + state.client_settings = {"vector_search": {"alias": "invalid_option"}} + + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "" + + st_common._vs_gen_selectbox("Select Alias:", ["valid1", "valid2"], "selected_vector_search_alias") + + # Invalid value should not cause error, selectbox should show empty + mock_selectbox.assert_called_once() + call_args = mock_selectbox.call_args + assert call_args[1]["index"] == 0 # Reset to empty option + + +############################################################################# +# Test Reset Button Callback Function +############################################################################# +class TestResetButtonCallback: + """Unit tests for the reset button callback within render_vector_store_selection""" + + def test_reset_clears_correct_fields(self, app_server): + """Test reset callback clears only the specified vector store fields""" + assert app_server is not None + + # Setup initial values + state.client_settings = { + "vector_search": { + "model": "openai/text-embed-3", + "chunk_size": 1000, + "chunk_overlap": 200, + "distance_metric": "cosine", + "vector_store": "vs_test", + "alias": "test_alias", + "index_type": "IVF", + "top_k": 10, + "search_type": "Similarity", + } + } + + # Set widget states + state.selected_vector_search_model = "openai/text-embed-3" + state.selected_vector_search_chunk_size = 1000 + state.selected_vector_search_chunk_overlap = 200 + state.selected_vector_search_distance_metric = "cosine" + state.selected_vector_search_alias = "test_alias" + state.selected_vector_search_index_type = "IVF" + + # Define and execute reset logic (simulating the reset callback) + fields_to_reset = [ + "model", + "chunk_size", + "chunk_overlap", + "distance_metric", 
+ "vector_store", + "alias", + "index_type", + ] + for key in fields_to_reset: + widget_key = f"selected_vector_search_{key}" + state[widget_key] = "" + state.client_settings["vector_search"][key] = "" + + # Verify the correct fields were cleared + for field in fields_to_reset: + assert state.client_settings["vector_search"][field] == "" + assert state[f"selected_vector_search_{field}"] == "" + + # Verify other fields were NOT cleared + assert state.client_settings["vector_search"]["top_k"] == 10 + assert state.client_settings["vector_search"]["search_type"] == "Similarity" + + def test_reset_enables_auto_population(self, app_server): + """Test that reset creates conditions for auto-population""" + assert app_server is not None + + # Setup with existing values + state.client_settings = {"vector_search": {"alias": "existing"}} + state.selected_vector_search_alias = "existing" + + # Execute reset logic + state.selected_vector_search_alias = "" + state.client_settings["vector_search"]["alias"] = "" + + # After reset, fields should be empty (ready for auto-population) + assert state.client_settings["vector_search"]["alias"] == "" + assert state.selected_vector_search_alias == "" + + # Now when _vs_gen_selectbox is called with a single option, it should auto-populate + with patch.object(st.sidebar, "selectbox") as mock_selectbox: + mock_selectbox.return_value = "auto_selected" + + st_common._vs_gen_selectbox("Select Alias:", ["auto_selected"], "selected_vector_search_alias") + + # Verify auto-population happened + assert state.client_settings["vector_search"]["alias"] == "auto_selected" + assert state.selected_vector_search_alias == "auto_selected" diff --git a/test/opentofu/OMRMetaSchema.yaml b/tests/opentofu/OMRMetaSchema.yaml similarity index 100% rename from test/opentofu/OMRMetaSchema.yaml rename to tests/opentofu/OMRMetaSchema.yaml diff --git a/test/opentofu/validate_omr_schema.py b/tests/opentofu/validate_omr_schema.py similarity index 100% rename from test/opentofu/validate_omr_schema.py rename to tests/opentofu/validate_omr_schema.py diff --git a/tests/server/integration/test_endpoints_settings.py b/tests/server/integration/test_endpoints_settings.py index 2eba8f74..5cfde6c0 100644 --- a/tests/server/integration/test_endpoints_settings.py +++ b/tests/server/integration/test_endpoints_settings.py @@ -105,8 +105,7 @@ def test_settings_update(self, client, auth_headers): updated_settings = Settings( client="default", ll_model=LargeLanguageSettings(model="updated-model", chat_history=False), - tools_enabled=["Vector Search"], - vector_search=VectorSearchSettings(grade=False, search_type="Similarity", top_k=5), + vector_search=VectorSearchSettings(enabled=True, grading=False, search_type="Similarity", top_k=5), oci=OciSettings(auth_profile="UPDATED"), ) @@ -126,8 +125,8 @@ def test_settings_update(self, client, auth_headers): # Check that the values were updated assert new_settings["ll_model"]["model"] == "updated-model" assert new_settings["ll_model"]["chat_history"] is False - assert new_settings["tools_enabled"] == ["Vector Search"] - assert new_settings["vector_search"]["grade"] is False + assert new_settings["vector_search"]["enabled"] is True + assert new_settings["vector_search"]["grading"] is False assert new_settings["vector_search"]["top_k"] == 5 assert new_settings["oci"]["auth_profile"] == "UPDATED" diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py index b614a00e..12e8f662 100644 --- 
a/tests/server/unit/api/utils/test_utils_chat.py +++ b/tests/server/unit/api/utils/test_utils_chat.py @@ -23,23 +23,14 @@ class TestChatUtils: """Test chat utility functions""" - @pytest.fixture - def sample_message(self): - """Sample chat message fixture""" - return ChatMessage(role="user", content="Hello, how are you?") - - @pytest.fixture - def sample_request(self, sample_message): - """Sample chat request fixture""" - return ChatRequest(messages=[sample_message], model="openai/gpt-4") - - @pytest.fixture - def sample_client_settings(self): - """Sample client settings fixture""" - return Settings( + def __init__(self): + """Setup test data""" + self.sample_message = ChatMessage(role="user", content="Hello, how are you?") + self.sample_request = ChatRequest(messages=[self.sample_message], model="openai/gpt-4") + self.sample_client_settings = Settings( client="test_client", ll_model=LargeLanguageSettings(model="openai/gpt-4", chat_history=True, temperature=0.7, max_tokens=4096), - vector_search=VectorSearchSettings(), + vector_search=VectorSearchSettings(enabled=False), oci=OciSettings(auth_profile="DEFAULT"), ) @@ -49,12 +40,11 @@ def sample_client_settings(self): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_success( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_request, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client ): """Test successful completion generation""" # Setup mocks - mock_get_client.return_value = sample_client_settings + mock_get_client.return_value = self.sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} @@ -68,7 +58,7 @@ async def mock_generator(): # Test the function results = [] - async for result in chat.completion_generator("test_client", sample_request, "completions"): + async for result in chat.completion_generator("test_client", self.sample_request, "completions"): results.append(result) # Verify results - for "completions" mode, we get stream chunks + final completion @@ -85,12 +75,11 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_streaming( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_request, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client ): """Test streaming completion generation""" # Setup mocks - mock_get_client.return_value = sample_client_settings + mock_get_client.return_value = self.sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} @@ -104,7 +93,7 @@ async def mock_generator(): # Test the function results = [] - async for result in chat.completion_generator("test_client", sample_request, "streams"): + async for result in chat.completion_generator("test_client", self.sample_request, "streams"): results.append(result) # Verify results - should include encoded stream chunks and finish marker @@ -128,13 +117,11 @@ async def test_completion_generator_with_vector_search( mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_request, - sample_client_settings, ): """Test completion generation with vector search enabled""" - # Setup settings with vector search enabled via tools_enabled - vector_search_settings = 
sample_client_settings.model_copy() - vector_search_settings.tools_enabled = ["Vector Search"] + # Setup settings with vector search enabled + vector_search_settings = self.sample_client_settings.model_copy() + vector_search_settings.vector_search.enabled = True # Setup mocks mock_get_client.return_value = vector_search_settings @@ -154,7 +141,7 @@ async def mock_generator(): # Test the function results = [] - async for result in chat.completion_generator("test_client", sample_request, "completions"): + async for result in chat.completion_generator("test_client", self.sample_request, "completions"): results.append(result) # Verify vector search setup @@ -168,15 +155,14 @@ async def mock_generator(): @patch("server.agents.chatbot.chatbot_graph.astream") @pytest.mark.asyncio async def test_completion_generator_no_model_specified( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_message, sample_client_settings + self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client ): """Test completion generation when no model is specified in request""" # Create request without model - request_no_model = ChatRequest(messages=[sample_message], model=None) + request_no_model = ChatRequest(messages=[self.sample_message], model=None) # Setup mocks - mock_get_client.return_value = sample_client_settings + mock_get_client.return_value = self.sample_client_settings mock_get_oci.return_value = MagicMock() mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} diff --git a/tests/server/unit/api/utils/test_utils_databases_crud.py b/tests/server/unit/api/utils/test_utils_databases_crud.py index 62c06ef0..f50d0a7d 100644 --- a/tests/server/unit/api/utils/test_utils_databases_crud.py +++ b/tests/server/unit/api/utils/test_utils_databases_crud.py @@ -17,8 +17,10 @@ class TestDatabases: """Test databases module functionality""" - sample_database: Database - sample_database_2: Database + def __init__(self): + """Initialize test data""" + self.sample_database = None + self.sample_database_2 = None def setup_method(self): """Setup test data before each test""" @@ -27,9 +29,36 @@ def setup_method(self): name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" ) - # test_get_all: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_all_databases - # test_get_by_name_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_specific_database - # test_get_by_name_not_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_raises_unknown_error + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_all(self, mock_database_objects): + """Test getting all databases when no name is provided""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get() + + assert result == [self.sample_database, self.sample_database_2] + assert len(result) == 2 + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_by_name_found(self, mock_database_objects): + """Test getting database by name when it exists""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get(name="test_db") + + assert result == [self.sample_database] + assert len(result) == 1 + + 
@patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_by_name_not_found(self, mock_database_objects): + """Test getting database by name when it doesn't exist""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database])) + mock_database_objects.__len__ = MagicMock(return_value=1) + + with pytest.raises(ValueError, match="nonexistent not found"): + databases.get(name="nonexistent") @patch("server.api.utils.databases.DATABASE_OBJECTS") def test_get_empty_list(self, mock_database_objects): @@ -50,9 +79,54 @@ def test_get_empty_list_with_name(self, mock_database_objects): with pytest.raises(ValueError, match="test_db not found"): databases.get(name="test_db") - # test_create_success: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_success - # test_create_already_exists: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_exists_error - # test_create_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_value_error_missing_fields + def test_create_success(self, db_container, db_objects_manager): + """Test successful database creation when database doesn't exist""" + assert db_container is not None + assert db_objects_manager is not None + # Clear the list to start fresh + databases.DATABASE_OBJECTS.clear() + + # Create a new database + new_database = Database(name="new_test_db", user="test_user", password="test_password", dsn="test_dsn") + + result = databases.create(new_database) + + # Verify database was added + assert len(databases.DATABASE_OBJECTS) == 1 + assert databases.DATABASE_OBJECTS[0].name == "new_test_db" + assert result == [new_database] + + def test_create_already_exists(self, db_container, db_objects_manager): + """Test database creation when database already exists""" + assert db_container is not None + assert db_objects_manager is not None + # Add a database to the list + databases.DATABASE_OBJECTS.clear() + existing_db = Database(name="existing_db", user="test_user", password="test_password", dsn="test_dsn") + databases.DATABASE_OBJECTS.append(existing_db) + + # Try to create a database with the same name + duplicate_db = Database(name="existing_db", user="other_user", password="other_password", dsn="other_dsn") + + # Should raise an error for duplicate database + with pytest.raises(ValueError, match="Database: existing_db already exists"): + databases.create(duplicate_db) + + # Verify only original database exists + assert len(databases.DATABASE_OBJECTS) == 1 + assert databases.DATABASE_OBJECTS[0] == existing_db + + def test_create_missing_user(self, db_container, db_objects_manager): + """Test database creation with missing user field""" + assert db_container is not None + assert db_objects_manager is not None + databases.DATABASE_OBJECTS.clear() + + # Create database with missing user + incomplete_db = Database(name="incomplete_db", password="test_password", dsn="test_dsn") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create(incomplete_db) def test_create_missing_password(self, db_container, db_objects_manager): """Test database creation with missing password field""" @@ -90,7 +164,27 @@ def test_create_multiple_missing_fields(self, db_container, db_objects_manager): with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): databases.create(incomplete_db) - # test_delete: See 
test/unit/server/api/utils/test_utils_databases.py::TestDelete::test_delete_removes_database + def test_delete(self, db_container, db_objects_manager): + """Test database deletion""" + assert db_container is not None + assert db_objects_manager is not None + # Setup test data + db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="test_db_2", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="test_db_3", user="user3", password="pass3", dsn="dsn3") + + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Delete middle database + databases.delete("test_db_2") + + # Verify deletion + assert len(databases.DATABASE_OBJECTS) == 2 + names = [db.name for db in databases.DATABASE_OBJECTS] + assert "test_db_1" in names + assert "test_db_2" not in names + assert "test_db_3" in names def test_delete_nonexistent(self, db_container, db_objects_manager): """Test deleting non-existent database""" @@ -142,7 +236,10 @@ def test_delete_multiple_same_name(self, db_container, db_objects_manager): assert len(databases.DATABASE_OBJECTS) == 1 assert databases.DATABASE_OBJECTS[0].name == "other" - # test_logger_exists: See test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(databases, "logger") + assert databases.logger.name == "api.utils.database" def test_get_filters_correctly(self, db_container, db_objects_manager): """Test that get correctly filters by name""" @@ -226,7 +323,12 @@ def test_create_real_scenario(self, db_container, db_objects_manager): class TestDbException: """Test custom database exception class""" - # test_db_exception_initialization: See test/unit/server/api/utils/test_utils_databases.py::TestDbException::test_db_exception_init + def test_db_exception_initialization(self): + """Test DbException initialization""" + exc = DbException(status_code=500, detail="Database error") + assert exc.status_code == 500 + assert exc.detail == "Database error" + assert str(exc) == "Database error" def test_db_exception_inheritance(self): """Test DbException inherits from Exception""" diff --git a/tests/server/unit/api/utils/test_utils_databases_functions.py b/tests/server/unit/api/utils/test_utils_databases_functions.py index a4ac6b74..e79f1d42 100644 --- a/tests/server/unit/api/utils/test_utils_databases_functions.py +++ b/tests/server/unit/api/utils/test_utils_databases_functions.py @@ -20,7 +20,9 @@ class TestDatabaseUtilsPrivateFunctions: """Test private utility functions""" - sample_database: Database + def __init__(self): + """Initialize test data""" + self.sample_database = None def setup_method(self): """Setup test data""" @@ -31,13 +33,90 @@ def setup_method(self): dsn=TEST_CONFIG["db_dsn"], ) - # test_test_function_success: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_active - # test_test_function_reconnect: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_refreshes_on_database_error - # test_test_function_value_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_value_error - # test_test_function_permission_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_permission_error - # test_test_function_connection_error: See 
test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_connection_error - # test_test_function_generic_exception: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_generic_exception - # test_get_vs_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestGetVs::test_get_vs_returns_list + def test_test_function_success(self, db_container): + """Test successful database connection test with real database""" + assert db_container is not None + # Connect to real database + conn = databases.connect(self.sample_database) + self.sample_database.set_connection(conn) + + try: + # Test the connection + databases._test(self.sample_database) + assert self.sample_database.connected is True + finally: + databases.disconnect(conn) + + @patch("oracledb.Connection") + def test_test_function_reconnect(self, mock_connection): + """Test database reconnection when ping fails""" + mock_connection.ping.side_effect = oracledb.DatabaseError("Connection lost") + self.sample_database.set_connection(mock_connection) + + with patch("server.api.utils.databases.connect") as mock_connect: + databases._test(self.sample_database) + mock_connect.assert_called_once_with(self.sample_database) + + @patch("oracledb.Connection") + def test_test_function_value_error(self, mock_connection): + """Test handling of value errors""" + mock_connection.ping.side_effect = ValueError("Invalid value") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 400 + assert "Database: Invalid value" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_permission_error(self, mock_connection): + """Test handling of permission errors""" + mock_connection.ping.side_effect = PermissionError("Access denied") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 401 + assert "Database: Access denied" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_connection_error(self, mock_connection): + """Test handling of connection errors""" + mock_connection.ping.side_effect = ConnectionError("Connection failed") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 503 + assert "Database: Connection failed" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_generic_exception(self, mock_connection): + """Test handling of generic exceptions""" + mock_connection.ping.side_effect = RuntimeError("Unknown error") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 500 + assert "Unknown error" in str(exc_info.value) + + def test_get_vs_with_real_database(self, db_container): + """Test vector storage retrieval with real database""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + # Test with empty result (no vector stores initially) + result = databases._get_vs(conn) + assert isinstance(result, list) + assert len(result) == 0 # Initially no vector stores + finally: + 
databases.disconnect(conn) @patch("server.api.utils.databases.execute_sql") def test_get_vs_with_mock_data(self, mock_execute_sql): @@ -93,7 +172,9 @@ def test_get_vs_malformed_json(self, mock_execute_sql): class TestDatabaseUtilsPublicFunctions: """Test public utility functions - connection and execution""" - sample_database: Database + def __init__(self): + """Initialize test data""" + self.sample_database = None def setup_method(self): """Setup test data""" @@ -104,10 +185,54 @@ def setup_method(self): dsn=TEST_CONFIG["db_dsn"], ) - # test_connect_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_success_real_db - # test_connect_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details - # test_connect_missing_password: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details - # test_connect_missing_dsn: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details + def test_connect_success_with_real_database(self, db_container): + """Test successful database connection with real database""" + assert db_container is not None + result = databases.connect(self.sample_database) + + try: + assert result is not None + assert isinstance(result, oracledb.Connection) + # Test that connection is active + result.ping() + finally: + databases.disconnect(result) + + def test_connect_missing_user(self): + """Test connection with missing user""" + incomplete_db = Database( + name="test_db", + user="", # Missing user + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"], + ) + + with pytest.raises(ValueError, match="missing connection details"): + databases.connect(incomplete_db) + + def test_connect_missing_password(self): + """Test connection with missing password""" + incomplete_db = Database( + name="test_db", + user=TEST_CONFIG["db_username"], + password="", # Missing password + dsn=TEST_CONFIG["db_dsn"], + ) + + with pytest.raises(ValueError, match="missing connection details"): + databases.connect(incomplete_db) + + def test_connect_missing_dsn(self): + """Test connection with missing DSN""" + incomplete_db = Database( + name="test_db", + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn="", # Missing DSN + ) + + with pytest.raises(ValueError, match="missing connection details"): + databases.connect(incomplete_db) def test_connect_with_wallet_configuration(self, db_container): """Test connection with wallet configuration""" @@ -150,7 +275,18 @@ def test_connect_wallet_password_without_location(self, db_container): # Expected if wallet doesn't exist pass - # test_connect_invalid_credentials: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_permission_error_invalid_credentials + def test_connect_invalid_credentials(self, db_container): + """Test connection with invalid credentials""" + assert db_container is not None + invalid_db = Database( + name="test_db", + user="invalid_user", + password="invalid_password", + dsn=TEST_CONFIG["db_dsn"], + ) + + with pytest.raises(PermissionError): + databases.connect(invalid_db) def test_connect_invalid_dsn(self, db_container): """Test connection with invalid DSN""" @@ -166,9 +302,45 @@ def test_connect_invalid_dsn(self, db_container): with pytest.raises(Exception): # Catch any exception - DNS resolution errors vary by environment 
databases.connect(invalid_db) - # test_disconnect_success: See test/unit/server/api/utils/test_utils_databases.py::TestDisconnect::test_disconnect_closes_connection - # test_execute_sql_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_returns_rows - # test_execute_sql_with_binds: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_with_binds + def test_disconnect_success(self, db_container): + """Test successful database disconnection""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + result = databases.disconnect(conn) + + assert result is None + # Try to use connection after disconnect - should fail + with pytest.raises(oracledb.InterfaceError): + conn.ping() + + def test_execute_sql_success_with_real_database(self, db_container): + """Test successful SQL execution with real database""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + # Test simple query + result = databases.execute_sql(conn, "SELECT 1 FROM DUAL") + assert result is not None + assert len(result) == 1 + assert result[0][0] == 1 + finally: + databases.disconnect(conn) + + def test_execute_sql_with_binds(self, db_container): + """Test SQL execution with bind variables using real database""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + binds = {"test_value": 42} + result = databases.execute_sql(conn, "SELECT :test_value FROM DUAL", binds) + assert result is not None + assert len(result) == 1 + assert result[0][0] == 42 + finally: + databases.disconnect(conn) def test_execute_sql_no_rows(self, db_container): """Test SQL execution that returns no rows""" @@ -243,20 +415,39 @@ def test_execute_sql_table_not_exists_error(self, db_container): finally: databases.disconnect(conn) - # test_execute_sql_invalid_syntax: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_raises_on_other_database_error + def test_execute_sql_invalid_syntax(self, db_container): + """Test SQL execution with invalid syntax""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + with pytest.raises(oracledb.DatabaseError): + databases.execute_sql(conn, "INVALID SQL STATEMENT") + finally: + databases.disconnect(conn) def test_drop_vs_function_exists(self): """Test that drop_vs function exists and is callable""" assert hasattr(databases, "drop_vs") assert callable(databases.drop_vs) - # test_drop_vs_calls_langchain: See test/unit/server/api/utils/test_utils_databases.py::TestDropVs::test_drop_vs_calls_langchain + @patch("langchain_community.vectorstores.oraclevs.drop_table_purge") + def test_drop_vs_calls_langchain(self, mock_drop_table): + """Test drop_vs calls LangChain drop_table_purge""" + mock_connection = MagicMock() + vs_name = "TEST_VECTOR_STORE" + + databases.drop_vs(mock_connection, vs_name) + + mock_drop_table.assert_called_once_with(mock_connection, vs_name) class TestDatabaseUtilsQueryFunctions: """Test public utility functions - get and client database functions""" - sample_database: Database + def __init__(self): + """Initialize test data""" + self.sample_database = None def setup_method(self): """Setup test data""" @@ -410,4 +601,7 @@ def test_get_client_database_with_validation(self, mock_get_settings, db_contain if db.connection: databases.disconnect(db.connection) - # test_logger_exists: See 
test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(databases, "logger") + assert databases.logger.name == "api.utils.database" diff --git a/tests/server/unit/api/utils/test_utils_embed.py b/tests/server/unit/api/utils/test_utils_embed.py index 13a63538..161aedc4 100644 --- a/tests/server/unit/api/utils/test_utils_embed.py +++ b/tests/server/unit/api/utils/test_utils_embed.py @@ -9,7 +9,6 @@ from pathlib import Path from unittest.mock import patch, mock_open, MagicMock -import pytest from langchain.docstore.document import Document as LangchainDocument from server.api.utils import embed @@ -19,17 +18,12 @@ class TestEmbedUtils: """Test embed utility functions""" - @pytest.fixture - def sample_document(self): - """Sample document fixture""" - return LangchainDocument( + def __init__(self): + """Setup test data""" + self.sample_document = LangchainDocument( page_content="This is a test document content.", metadata={"source": "/path/to/test_file.txt", "page": 1} ) - - @pytest.fixture - def sample_split_doc(self): - """Sample split document fixture""" - return LangchainDocument( + self.sample_split_doc = LangchainDocument( page_content="This is a chunk of content.", metadata={"source": "/path/to/test_file.txt", "start_index": 0} ) @@ -60,12 +54,12 @@ def test_get_temp_directory_tmp_fallback(self, mock_mkdir, mock_exists): @patch("builtins.open", new_callable=mock_open) @patch("os.path.getsize") @patch("json.dumps") - def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): + def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file): """Test document to JSON conversion with default output directory""" mock_json_dumps.return_value = '{"test": "data"}' mock_getsize.return_value = 100 - result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/tmp") + result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/tmp") mock_file.assert_called_once() mock_json_dumps.assert_called_once() @@ -75,12 +69,12 @@ def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_fi @patch("builtins.open", new_callable=mock_open) @patch("os.path.getsize") @patch("json.dumps") - def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): + def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file): """Test document to JSON conversion with custom output directory""" mock_json_dumps.return_value = '{"test": "data"}' mock_getsize.return_value = 100 - result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/custom/output") + result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/custom/output") mock_file.assert_called_once() mock_json_dumps.assert_called_once() @@ -96,10 +90,9 @@ def test_logger_exists(self): class TestGetVectorStoreFiles: """Test get_vector_store_files() function""" - @pytest.fixture - def sample_db(self): - """Sample database fixture""" - return Database( + def __init__(self): + """Setup test data""" + self.sample_db = Database( name="TEST_DB", user="test_user", password="", @@ -108,7 +101,7 @@ def sample_db(self): @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connect, sample_db): + def 
test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connect): """Test retrieving file list with complete metadata""" # Mock database connection and cursor mock_conn = MagicMock() @@ -139,7 +132,7 @@ def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connec ] # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") + result = embed.get_vector_store_files(self.sample_db, "TEST_VS") # Verify assert result["vector_store"] == "TEST_VS" @@ -159,7 +152,7 @@ def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connec @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_connect, sample_db): + def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_connect): """Test handling of Decimal size from Oracle NUMBER type""" # Mock database connection mock_conn = MagicMock() @@ -178,7 +171,7 @@ def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_c ] # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") + result = embed.get_vector_store_files(self.sample_db, "TEST_VS") # Verify Decimal was converted to int assert result["files"][0]["size"] == 1024000 @@ -186,7 +179,7 @@ def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_c @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect, sample_db): + def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect): """Test retrieving files with old metadata format (source field)""" # Mock database connection mock_conn = MagicMock() @@ -201,7 +194,7 @@ def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect, ] # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") + result = embed.get_vector_store_files(self.sample_db, "TEST_VS") # Verify fallback to source field worked assert result["total_files"] == 1 @@ -210,7 +203,7 @@ def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect, @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, mock_connect, sample_db): + def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, mock_connect): """Test detection of orphaned chunks without valid filename""" # Mock database connection mock_conn = MagicMock() @@ -227,7 +220,7 @@ def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, moc ] # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") + result = embed.get_vector_store_files(self.sample_db, "TEST_VS") # Verify assert result["total_files"] == 1 @@ -237,7 +230,7 @@ def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, moc @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect, sample_db): + def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect): """Test retrieving from empty vector store""" # Mock database connection mock_conn = MagicMock() @@ -249,7 +242,7 @@ def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect mock_cursor.fetchall.return_value = [] # Execute - result = 
embed.get_vector_store_files(sample_db, "EMPTY_VS") + result = embed.get_vector_store_files(self.sample_db, "EMPTY_VS") # Verify assert result["vector_store"] == "EMPTY_VS" @@ -260,7 +253,7 @@ def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect @patch("server.api.utils.databases.connect") @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_connect, sample_db): + def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_connect): """Test that files are sorted alphabetically by filename""" # Mock database connection mock_conn = MagicMock() @@ -276,7 +269,7 @@ def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_c ] # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") + result = embed.get_vector_store_files(self.sample_db, "TEST_VS") # Verify sorted order filenames = [f["filename"] for f in result["files"]] diff --git a/tests/server/unit/api/utils/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py index d7451dde..ef1a2f3c 100644 --- a/tests/server/unit/api/utils/test_utils_models.py +++ b/tests/server/unit/api/utils/test_utils_models.py @@ -21,11 +21,29 @@ class TestModelsExceptions: """Test custom exception classes""" - # test_url_unreachable_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_url_unreachable_error_is_value_error - # test_invalid_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_invalid_model_error_is_value_error - # test_exists_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_exists_model_error_is_value_error - # test_unknown_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_unknown_model_error_is_value_error - pass + def test_url_unreachable_error(self): + """Test URLUnreachableError exception""" + error = URLUnreachableError("URL is unreachable") + assert str(error) == "URL is unreachable" + assert isinstance(error, ValueError) + + def test_invalid_model_error(self): + """Test InvalidModelError exception""" + error = InvalidModelError("Invalid model data") + assert str(error) == "Invalid model data" + assert isinstance(error, ValueError) + + def test_exists_model_error(self): + """Test ExistsModelError exception""" + error = ExistsModelError("Model already exists") + assert str(error) == "Model already exists" + assert isinstance(error, ValueError) + + def test_unknown_model_error(self): + """Test UnknownModelError exception""" + error = UnknownModelError("Model not found") + assert str(error) == "Model not found" + assert isinstance(error, ValueError) ##################################################### @@ -34,45 +52,94 @@ class TestModelsExceptions: class TestModelsCRUD: """Test models module functionality""" - @pytest.fixture - def sample_model(self): - """Sample model fixture""" - return Model( + def __init__(self): + """Setup test data for all tests""" + self.sample_model = Model( id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" ) + self.disabled_model = Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) - @pytest.fixture - def disabled_model(self): - """Disabled model fixture""" - return Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_model_all_models(self, mock_model_objects): 
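# Patching a module-level list with a MagicMock means every dunder the code under
# test touches must be stubbed explicitly: iteration goes through `__iter__`,
# `len(...)` through `__len__`. A standalone sketch of the technique; note that
# `iter(...)` is single-use, which is why each test rebuilds the stub:
from unittest.mock import MagicMock

mock_objects = MagicMock()
mock_objects.__iter__ = MagicMock(return_value=iter(["a", "b"]))
mock_objects.__len__ = MagicMock(return_value=2)

assert list(mock_objects) == ["a", "b"]  # consumed via the stubbed __iter__
assert len(mock_objects) == 2  # answered by the stubbed __len__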
+ """Test getting all models without filters""" + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model, self.disabled_model])) + mock_model_objects.__len__ = MagicMock(return_value=2) + + result = models.get() - # test_get_model_all_models: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_all_models + assert result == [self.sample_model, self.disabled_model] @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_found(self, mock_model_objects, sample_model): + def test_get_model_by_id_found(self, mock_model_objects): """Test getting model by ID when it exists""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) mock_model_objects.__len__ = MagicMock(return_value=1) (result,) = models.get(model_id="test-model") - assert result == sample_model + assert result == self.sample_model @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_not_found(self, mock_model_objects, sample_model): + def test_get_model_by_id_not_found(self, mock_model_objects): """Test getting model by ID when it doesn't exist""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) mock_model_objects.__len__ = MagicMock(return_value=1) with pytest.raises(UnknownModelError, match="nonexistent not found"): models.get(model_id="nonexistent") - # test_get_model_by_provider: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_provider - # test_get_model_by_type: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_type - # test_get_model_exclude_disabled: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_exclude_disabled + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_model_by_provider(self, mock_model_objects): + """Test filtering models by provider""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + (result,) = models.get(model_provider="openai") + + # Since only one model matches provider="openai", it will return a list of single model + assert result == self.sample_model + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_model_by_type(self, mock_model_objects): + """Test filtering models by type""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + result = models.get(model_type="ll") + + assert result == all_models + + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_get_model_exclude_disabled(self, mock_model_objects): + """Test excluding disabled models""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + (result,) = models.get(include_disabled=False) + assert result == self.sample_model + + @patch("server.api.utils.models.MODEL_OBJECTS", []) + @patch("server.api.utils.models.is_url_accessible") + def test_create_model_success(self, mock_url_check): + """Test successful model creation""" + mock_url_check.return_value = (True, None) + + result = 
models.create(self.sample_model) + + assert result == self.sample_model + assert result in models.MODEL_OBJECTS + + @patch("server.api.utils.models.MODEL_OBJECTS") + @patch("server.api.utils.models.get") + def test_create_model_already_exists(self, mock_get_model, _mock_model_objects): + """Test creating model that already exists""" + mock_get_model.return_value = self.sample_model # Model already exists - # test_create_model_success: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_success - # test_create_model_already_exists: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_raises_exists_error + with pytest.raises(ExistsModelError, match="Model: openai/test-model already exists"): + models.create(self.sample_model) @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") @@ -94,16 +161,32 @@ def test_create_model_unreachable_url(self, mock_url_check): assert result.enabled is False @patch("server.api.utils.models.MODEL_OBJECTS", []) - def test_create_model_skip_url_check(self, sample_model): + def test_create_model_skip_url_check(self): """Test creating model without URL check""" - result = models.create(sample_model, check_url=False) + result = models.create(self.sample_model, check_url=False) - assert result == sample_model + assert result == self.sample_model assert result in models.MODEL_OBJECTS - # test_delete_model: See test/unit/server/api/utils/test_utils_models.py::TestDelete::test_delete_removes_model - # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists - pass + @patch("server.api.utils.models.MODEL_OBJECTS") + def test_delete_model(self, mock_model_objects): + """Test model deletion""" + test_models = [ + Model(id="test-model", provider="openai", type="ll"), + Model(id="other-model", provider="anthropic", type="ll"), + ] + mock_model_objects.__setitem__ = MagicMock() + mock_model_objects.__iter__ = MagicMock(return_value=iter(test_models)) + + models.delete("openai", "test-model") + + # Verify the slice assignment was called + mock_model_objects.__setitem__.assert_called_once() + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(models, "logger") + assert models.logger.name == "api.utils.models" ##################################################### @@ -112,19 +195,33 @@ def test_create_model_skip_url_check(self, sample_model): class TestModelsUtils: """Test models utility functions""" - @pytest.fixture - def sample_model(self): - """Sample model fixture""" - return Model( + def __init__(self): + """Setup test data""" + self.sample_model = Model( id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" ) + self.sample_oci_config = get_sample_oci_config() - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return get_sample_oci_config() + @patch("server.api.utils.models.MODEL_OBJECTS", []) + @patch("server.api.utils.models.is_url_accessible") + def test_update_success(self, mock_url_check): + """Test successful model update""" + # First create the model + models.MODEL_OBJECTS.append(self.sample_model) + mock_url_check.return_value = (True, None) - # test_update_success: See test/unit/server/api/utils/test_utils_models.py::TestUpdate::test_update_success + update_payload = Model( + id="test-model", + provider="openai", + type="ll", + enabled=True, + api_base="https://api.openai.com", + 
temperature=0.8, + ) + + result = models.update(update_payload) + + assert result.temperature == 0.8 @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") @@ -165,10 +262,10 @@ def test_update_embedding_model_max_chunk_size(self, mock_url_check): @patch("server.api.utils.models.MODEL_OBJECTS", []) @patch("server.api.utils.models.is_url_accessible") - def test_update_multiple_fields(self, mock_url_check, sample_model): + def test_update_multiple_fields(self, mock_url_check): """Test updating multiple fields at once""" # Create a model - models.MODEL_OBJECTS.append(sample_model) + models.MODEL_OBJECTS.append(self.sample_model) mock_url_check.return_value = (True, None) # Update multiple fields @@ -189,26 +286,62 @@ def test_update_multiple_fields(self, mock_url_check, sample_model): assert result.temperature == 0.5 assert result.max_tokens == 2048 - # test_get_full_config_success: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_success - # test_get_full_config_unknown_model: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_raises_unknown_model - # test_get_litellm_config_basic: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_basic + @patch("server.api.utils.models.get") + def test_get_full_config_success(self, mock_get_model): + """Test successful full config retrieval""" + mock_get_model.return_value = [self.sample_model] + model_config = {"model": "openai/gpt-4", "temperature": 0.8} + + full_config, provider = models._get_full_config(model_config, self.sample_oci_config) + + assert provider == "openai" + assert full_config["temperature"] == 0.8 + assert full_config["id"] == "test-model" + mock_get_model.assert_called_once_with(model_provider="openai", model_id="gpt-4", include_disabled=False) + + @patch("server.api.utils.models.get") + def test_get_full_config_unknown_model(self, mock_get_model): + """Test full config retrieval with unknown model""" + mock_get_model.side_effect = UnknownModelError("Model not found") + model_config = {"model": "unknown/model"} + + with pytest.raises(UnknownModelError): + models._get_full_config(model_config, self.sample_oci_config) @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config, sample_oci_config): + def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config): + """Test basic LiteLLM config generation""" + mock_get_full_config.return_value = ( + {"temperature": 0.7, "max_tokens": 4096, "api_base": "https://api.openai.com"}, + "openai", + ) + mock_get_params.return_value = ["temperature", "max_tokens"] + model_config = {"model": "openai/gpt-4"} + + result = models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["model"] == "openai/gpt-4" + assert result["temperature"] == 0.7 + assert result["max_tokens"] == 4096 + assert result["drop_params"] is True + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config): """Test LiteLLM config generation for Cohere""" mock_get_full_config.return_value = ({"api_base": "https://custom.cohere.com/v1"}, "cohere") mock_get_params.return_value = [] model_config = {"model": "cohere/command"} - result = models.get_litellm_config(model_config, 
sample_oci_config) + result = models.get_litellm_config(model_config, self.sample_oci_config) assert result["api_base"] == "https://api.cohere.ai/compatibility/v1" assert result["model"] == "cohere/command" @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config, sample_oci_config): + def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config): """Test LiteLLM config generation for xAI""" mock_get_full_config.return_value = ( {"temperature": 0.7, "presence_penalty": 0.1, "frequency_penalty": 0.1}, @@ -217,25 +350,42 @@ def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config, sam mock_get_params.return_value = ["temperature", "presence_penalty", "frequency_penalty"] model_config = {"model": "xai/grok"} - result = models.get_litellm_config(model_config, sample_oci_config) + result = models.get_litellm_config(model_config, self.sample_oci_config) assert result["temperature"] == 0.7 assert "presence_penalty" not in result assert "frequency_penalty" not in result - # test_get_litellm_config_oci: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_oci_provider + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config): + """Test LiteLLM config generation for OCI""" + mock_get_full_config.return_value = ({"temperature": 0.7}, "oci") + mock_get_params.return_value = ["temperature"] + model_config = {"model": "oci/cohere.command"} + + result = models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["oci_user"] == "ocid1.user.oc1..testuser" + assert result["oci_fingerprint"] == "test-fingerprint" + assert result["oci_tenancy"] == "ocid1.tenancy.oc1..testtenant" + assert result["oci_region"] == "us-ashburn-1" + assert result["oci_key_file"] == "/path/to/key.pem" @patch("server.api.utils.models._get_full_config") @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config, sample_oci_config): + def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config): """Test LiteLLM config generation for Giskard""" mock_get_full_config.return_value = ({"temperature": 0.7, "model": "test-model"}, "openai") mock_get_params.return_value = ["temperature", "model"] model_config = {"model": "openai/gpt-4"} - result = models.get_litellm_config(model_config, sample_oci_config, giskard=True) + result = models.get_litellm_config(model_config, self.sample_oci_config, giskard=True) assert "model" not in result assert "temperature" not in result - # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(models, "logger") + assert models.logger.name == "api.utils.models" diff --git a/tests/server/unit/api/utils/test_utils_oci.py b/tests/server/unit/api/utils/test_utils_oci.py index 39c0a4f1..02c5c217 100644 --- a/tests/server/unit/api/utils/test_utils_oci.py +++ b/tests/server/unit/api/utils/test_utils_oci.py @@ -19,37 +19,61 @@ class TestOciException: """Test custom OCI exception class""" - # test_oci_exception_initialization: See test/unit/server/api/utils/test_utils_oci.py::TestOciException::test_oci_exception_init + def 
test_oci_exception_initialization(self): + """Test OciException initialization""" + exc = OciException(status_code=400, detail="Invalid configuration") + assert exc.status_code == 400 + assert exc.detail == "Invalid configuration" + assert str(exc) == "Invalid configuration" class TestOciGet: """Test OCI get() function""" - @pytest.fixture - def sample_oci_default(self): - """Sample OCI config with DEFAULT profile""" - return OracleCloudSettings( + def __init__(self): + """Setup test data for all tests""" + self.sample_oci_default = OracleCloudSettings( auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" ) - - @pytest.fixture - def sample_oci_custom(self): - """Sample OCI config with CUSTOM profile""" - return OracleCloudSettings( + self.sample_oci_custom = OracleCloudSettings( auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" ) + self.sample_client_settings = Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS", []) + def test_get_no_objects_configured(self): + """Test getting OCI settings when none are configured""" + with pytest.raises(ValueError, match="not configured"): + oci_utils.get() + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS", new_callable=list) + def test_get_all(self, mock_oci_objects): + """Test getting all OCI settings when no filters are provided""" + all_oci = [self.sample_oci_default, self.sample_oci_custom] + mock_oci_objects.extend(all_oci) + + result = oci_utils.get() + + assert result == all_oci + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + def test_get_by_auth_profile_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it exists""" + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) + + result = oci_utils.get(auth_profile="CUSTOM") + + assert result == self.sample_oci_custom - @pytest.fixture - def sample_client_settings(self): - """Sample client settings fixture""" - return Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + def test_get_by_auth_profile_not_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it doesn't exist""" + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) - # test_get_no_objects_configured: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_when_not_configured - # test_get_all: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_returns_all_oci_objects - # test_get_by_auth_profile_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_by_auth_profile - # test_get_by_auth_profile_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_profile_not_found + with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): + oci_utils.get(auth_profile="NONEXISTENT") - def test_get_by_client_with_oci_settings(self, sample_client_settings, sample_oci_default, sample_oci_custom): + def test_get_by_client_with_oci_settings(self): """Test getting OCI settings by client when client has OCI settings""" from server.bootstrap import bootstrap @@ -59,18 +83,18 @@ def test_get_by_client_with_oci_settings(self, sample_client_settings, sample_oc try: # Replace with test data - bootstrap.SETTINGS_OBJECTS = [sample_client_settings] - bootstrap.OCI_OBJECTS = [sample_oci_default, sample_oci_custom] + 
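# The tests below swap real module globals and restore them in `finally`, which
# keeps state isolated even when an assertion fails. A self-contained sketch of
# the same idiom against a toy namespace (SimpleNamespace stands in for the
# bootstrap module; names are illustrative):
from types import SimpleNamespace

registry = SimpleNamespace(OBJECTS=["default"])

def test_swap_and_restore():
    orig = registry.OBJECTS
    try:
        registry.OBJECTS = ["custom"]
        assert registry.OBJECTS == ["custom"]
    finally:
        registry.OBJECTS = orig
    assert registry.OBJECTS == ["default"]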
bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] + bootstrap.OCI_OBJECTS = [self.sample_oci_default, self.sample_oci_custom] result = oci_utils.get(client="test_client") - assert result == sample_oci_custom + assert result == self.sample_oci_custom finally: # Restore originals bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - def test_get_by_client_without_oci_settings(self, sample_oci_default): + def test_get_by_client_without_oci_settings(self): """Test getting OCI settings by client when client has no OCI settings""" from server.bootstrap import bootstrap @@ -83,19 +107,26 @@ def test_get_by_client_without_oci_settings(self, sample_oci_default): try: # Replace with test data bootstrap.SETTINGS_OBJECTS = [client_settings_no_oci] - bootstrap.OCI_OBJECTS = [sample_oci_default] + bootstrap.OCI_OBJECTS = [self.sample_oci_default] result = oci_utils.get(client="test_client") - assert result == sample_oci_default + assert result == self.sample_oci_default finally: # Restore originals bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - # test_get_by_client_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_client_not_found + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + @patch("server.bootstrap.bootstrap.SETTINGS_OBJECTS") + def test_get_by_client_not_found(self, mock_settings_objects, _mock_oci_objects): + """Test getting OCI settings when client doesn't exist""" + mock_settings_objects.__iter__ = MagicMock(return_value=iter([])) - def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_oci_default): + with pytest.raises(ValueError, match="client test_client not found"): + oci_utils.get(client="test_client") + + def test_get_by_client_no_matching_profile(self): """Test getting OCI settings by client when no matching profile exists""" from server.bootstrap import bootstrap @@ -105,8 +136,8 @@ def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_ try: # Replace with test data - bootstrap.SETTINGS_OBJECTS = [sample_client_settings] - bootstrap.OCI_OBJECTS = [sample_oci_default] # Only DEFAULT profile + bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] + bootstrap.OCI_OBJECTS = [self.sample_oci_default] # Only DEFAULT profile expected_error = "No settings found for client 'test_client' with auth_profile 'CUSTOM'" with pytest.raises(ValueError, match=expected_error): @@ -116,15 +147,48 @@ def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_ bootstrap.SETTINGS_OBJECTS = orig_settings bootstrap.OCI_OBJECTS = orig_oci - # test_get_both_client_and_auth_profile: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_both_params + def test_get_both_client_and_auth_profile(self): + """Test that providing both client and auth_profile raises an error""" + with pytest.raises(ValueError, match="provide either 'client' or 'auth_profile', not both"): + oci_utils.get(client="test_client", auth_profile="CUSTOM") class TestGetSigner: """Test get_signer() function""" - # test_get_signer_instance_principal: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_instance_principal - # test_get_signer_oke_workload_identity: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_oke_workload_identity - # test_get_signer_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_api_key_returns_none + def 
test_get_signer_instance_principal(self): + """Test get_signer with instance_principal authentication""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="instance_principal") + + with patch("oci.auth.signers.InstancePrincipalsSecurityTokenSigner") as mock_signer: + mock_instance = MagicMock() + mock_signer.return_value = mock_instance + + result = oci_utils.get_signer(config) + + assert result == mock_instance + mock_signer.assert_called_once() + + def test_get_signer_oke_workload_identity(self): + """Test get_signer with oke_workload_identity authentication""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="oke_workload_identity") + + with patch("oci.auth.signers.get_oke_workload_identity_resource_principal_signer") as mock_signer: + mock_instance = MagicMock() + mock_signer.return_value = mock_instance + + result = oci_utils.get_signer(config) + + assert result == mock_instance + mock_signer.assert_called_once() + + def test_get_signer_api_key(self): + """Test get_signer with api_key authentication (returns None)""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="api_key") + + result = oci_utils.get_signer(config) + + assert result is None def test_get_signer_security_token(self): """Test get_signer with security_token authentication (returns None)""" @@ -138,10 +202,9 @@ def test_get_signer_security_token(self): class TestInitClient: """Test init_client() function""" - @pytest.fixture - def api_key_config(self): - """API key configuration fixture""" - return OracleCloudSettings( + def __init__(self): + """Setup test data""" + self.api_key_config = OracleCloudSettings( auth_profile="DEFAULT", authentication="api_key", region="us-ashburn-1", @@ -151,13 +214,24 @@ def api_key_config(self): key_file="/path/to/key.pem", ) - # test_init_client_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_standard_auth + @patch("oci.object_storage.ObjectStorageClient") + @patch.object(oci_utils, "get_signer", return_value=None) + def test_init_client_api_key(self, mock_get_signer, mock_client_class): + """Test init_client with API key authentication""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + + assert result == mock_client + mock_get_signer.assert_called_once_with(self.api_key_config) + mock_client_class.assert_called_once() @patch("oci.generative_ai_inference.GenerativeAiInferenceClient") @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_genai_with_endpoint(self, _mock_get_signer, mock_client_class, api_key_config): + def test_init_client_genai_with_endpoint(self, _mock_get_signer, mock_client_class): """Test init_client for GenAI sets correct service endpoint""" - genai_config = api_key_config.model_copy() + genai_config = self.api_key_config.model_copy() genai_config.genai_compartment_id = "ocid1.compartment.oc1..test" genai_config.genai_region = "us-chicago-1" @@ -265,73 +339,149 @@ def test_init_client_with_security_token( mock_load_key.assert_called_once_with("/path/to/key.pem") mock_sec_token_signer.assert_called_once_with("mock_token_content", mock_private_key) - # test_init_client_invalid_config: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_raises_oci_exception_on_invalid_config + @patch("oci.object_storage.ObjectStorageClient") + @patch.object(oci_utils, "get_signer", return_value=None) + 
def test_init_client_invalid_config(self, _mock_get_signer, mock_client_class): + """Test init_client with invalid config raises OciException""" + mock_client_class.side_effect = oci.exceptions.InvalidConfig("Bad config") + + with pytest.raises(OciException) as exc_info: + oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + + assert exc_info.value.status_code == 400 + assert "Invalid Config" in str(exc_info.value) class TestOciUtils: """Test OCI utility functions""" - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return get_sample_oci_config() + def __init__(self): + """Setup test data""" + self.sample_oci_config = get_sample_oci_config() + + def test_init_genai_client(self): + """Test GenAI client initialization""" + with patch.object(oci_utils, "init_client") as mock_init_client: + mock_client = MagicMock() + mock_init_client.return_value = mock_client - # test_init_genai_client: See test/unit/server/api/utils/test_utils_oci.py::TestInitGenaiClient::test_init_genai_client_calls_init_client - # test_get_namespace_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_success + result = oci_utils.init_genai_client(self.sample_oci_config) + + assert result == mock_client + mock_init_client.assert_called_once_with( + oci.generative_ai_inference.GenerativeAiInferenceClient, self.sample_oci_config + ) @patch.object(oci_utils, "init_client") - def test_get_namespace_invalid_config(self, mock_init_client, sample_oci_config): + def test_get_namespace_success(self, mock_init_client): + """Test successful namespace retrieval""" + mock_client = MagicMock() + mock_client.get_namespace.return_value.data = "test-namespace" + mock_init_client.return_value = mock_client + + result = oci_utils.get_namespace(self.sample_oci_config) + + assert result == "test-namespace" + assert self.sample_oci_config.namespace == "test-namespace" + + @patch.object(oci_utils, "init_client") + def test_get_namespace_invalid_config(self, mock_init_client): """Test namespace retrieval with invalid config""" mock_client = MagicMock() mock_client.get_namespace.side_effect = oci.exceptions.InvalidConfig("Invalid config") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) + oci_utils.get_namespace(self.sample_oci_config) assert exc_info.value.status_code == 400 assert "Invalid Config" in str(exc_info.value) - # test_get_namespace_file_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_file_not_found - # test_get_namespace_service_error: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_service_error + @patch.object(oci_utils, "init_client") + def test_get_namespace_file_not_found(self, mock_init_client): + """Test namespace retrieval with file not found error""" + mock_init_client.side_effect = FileNotFoundError("Key file not found") + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 400 + assert "Invalid Key Path" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_unbound_local_error(self, mock_init_client, sample_oci_config): + def test_get_namespace_service_error(self, mock_init_client): + """Test namespace retrieval with service error""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = 
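# The namespace tests pin a mapping from SDK failures to OciException status
# codes: 400 for invalid config or key path, 401 for service auth errors, 503
# for transport errors, 500 as the fallback. A compact sketch of such a
# translation layer (all names here are illustrative, not the project's code):
class UpstreamAuthError(Exception):
    """Stand-in for an SDK authentication error"""

class DomainError(Exception):
    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail

def translate(exc: Exception) -> DomainError:
    if isinstance(exc, FileNotFoundError):
        return DomainError(400, f"Invalid Key Path: {exc}")
    if isinstance(exc, UpstreamAuthError):
        return DomainError(401, f"AuthN Error: {exc}")
    if isinstance(exc, ConnectionError):
        return DomainError(503, str(exc))
    return DomainError(500, str(exc))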
oci.exceptions.ServiceError( + status=401, code="NotAuthenticated", headers={}, message="Auth failed" + ) + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 401 + assert "AuthN Error" in str(exc_info.value) + + @patch.object(oci_utils, "init_client") + def test_get_namespace_unbound_local_error(self, mock_init_client): """Test namespace retrieval with unbound local error""" mock_client = MagicMock() mock_client.get_namespace.side_effect = UnboundLocalError("local variable referenced before assignment") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) + oci_utils.get_namespace(self.sample_oci_config) assert exc_info.value.status_code == 500 assert "No Configuration" in str(exc_info.value) @patch.object(oci_utils, "init_client") - def test_get_namespace_request_exception(self, mock_init_client, sample_oci_config): + def test_get_namespace_request_exception(self, mock_init_client): """Test namespace retrieval with request exception""" mock_client = MagicMock() mock_client.get_namespace.side_effect = oci.exceptions.RequestException("Connection timeout") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) + oci_utils.get_namespace(self.sample_oci_config) assert exc_info.value.status_code == 503 @patch.object(oci_utils, "init_client") - def test_get_namespace_generic_exception(self, mock_init_client, sample_oci_config): + def test_get_namespace_generic_exception(self, mock_init_client): """Test namespace retrieval with generic exception""" mock_client = MagicMock() mock_client.get_namespace.side_effect = Exception("Unexpected error") mock_init_client.return_value = mock_client with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) + oci_utils.get_namespace(self.sample_oci_config) assert exc_info.value.status_code == 500 assert "Unexpected error" in str(exc_info.value) - # test_get_regions_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetRegions::test_get_regions_returns_list - # test_logger_exists: See test/unit/server/api/utils/test_utils_oci.py::TestLoggerConfiguration::test_logger_exists + @patch.object(oci_utils, "init_client") + def test_get_regions_success(self, mock_init_client): + """Test successful regions retrieval""" + mock_client = MagicMock() + mock_region = MagicMock() + mock_region.is_home_region = True + mock_region.region_key = "IAD" + mock_region.region_name = "us-ashburn-1" + mock_region.status = "READY" + mock_client.list_region_subscriptions.return_value.data = [mock_region] + mock_init_client.return_value = mock_client + + result = oci_utils.get_regions(self.sample_oci_config) + + assert len(result) == 1 + assert result[0]["is_home_region"] is True + assert result[0]["region_key"] == "IAD" + assert result[0]["region_name"] == "us-ashburn-1" + assert result[0]["status"] == "READY" + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(oci_utils, "logger") + assert oci_utils.logger.name == "api.utils.oci" diff --git a/tests/server/unit/api/utils/test_utils_oci_refresh.py b/tests/server/unit/api/utils/test_utils_oci_refresh.py index 72b81920..7857c306 100644 --- a/tests/server/unit/api/utils/test_utils_oci_refresh.py +++ b/tests/server/unit/api/utils/test_utils_oci_refresh.py @@ 
-8,8 +8,6 @@ from datetime import datetime from unittest.mock import patch, MagicMock -import pytest - from server.api.utils import oci as oci_utils from common.schema import OracleCloudSettings @@ -17,10 +15,9 @@ class TestGetBucketObjectsWithMetadata: """Test get_bucket_objects_with_metadata() function""" - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return OracleCloudSettings( + def __init__(self): + """Setup test data""" + self.sample_oci_config = OracleCloudSettings( auth_profile="DEFAULT", namespace="test-namespace", compartment_id="ocid1.compartment.oc1..test", @@ -38,7 +35,7 @@ def create_mock_object(self, name, size, etag, time_modified, md5): return mock_obj @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_with_metadata_success(self, mock_init_client, sample_oci_config): + def test_get_bucket_objects_with_metadata_success(self, mock_init_client): """Test successful retrieval of bucket objects with metadata""" # Create mock objects time1 = datetime(2025, 11, 1, 10, 0, 0) @@ -59,7 +56,7 @@ def test_get_bucket_objects_with_metadata_success(self, mock_init_client, sample mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) # Verify assert len(result) == 2 @@ -81,7 +78,7 @@ def test_get_bucket_objects_with_metadata_success(self, mock_init_client, sample assert "etag" in call_kwargs["fields"] @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client, sample_oci_config): + def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client): """Test that unsupported file types are filtered out""" # Create mock objects with various file types mock_pdf = self.create_mock_object("doc.pdf", 1000, "etag1", datetime.now(), "md5-1") @@ -97,7 +94,7 @@ def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client, sa mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) # Verify only supported types are included assert len(result) == 2 @@ -108,7 +105,7 @@ def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client, sa assert "archive.zip" not in names @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_empty_bucket(self, mock_init_client, sample_oci_config): + def test_get_bucket_objects_empty_bucket(self, mock_init_client): """Test handling of empty bucket""" # Mock empty bucket mock_client = MagicMock() @@ -118,13 +115,13 @@ def test_get_bucket_objects_empty_bucket(self, mock_init_client, sample_oci_conf mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("empty-bucket", sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("empty-bucket", self.sample_oci_config) # Verify assert len(result) == 0 @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_none_time_modified(self, mock_init_client, sample_oci_config): + def test_get_bucket_objects_none_time_modified(self, mock_init_client): """Test handling of objects with None time_modified""" # Create mock object with None time_modified mock_obj = self.create_mock_object( @@ -139,7 +136,7 @@ def 
test_get_bucket_objects_none_time_modified(self, mock_init_client, sample_oc mock_init_client.return_value = mock_client # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) + result = oci_utils.get_bucket_objects_with_metadata("test-bucket", self.sample_oci_config) # Verify time_modified is None assert len(result) == 1 diff --git a/tests/server/unit/api/utils/test_utils_settings.py b/tests/server/unit/api/utils/test_utils_settings.py index 3027874f..aebff4d0 100644 --- a/tests/server/unit/api/utils/test_utils_settings.py +++ b/tests/server/unit/api/utils/test_utils_settings.py @@ -44,12 +44,67 @@ def make_sample_config_data(): class TestClientSettings: """Test client settings CRUD operations""" - # test_create_client_success: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_success - # test_create_client_already_exists: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_raises_on_existing - # test_get_client_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_success - # test_get_client_not_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_raises_on_not_found - # test_update_client: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateClient::test_update_client_success - pass + @patch("server.api.utils.settings.bootstrap") + def test_create_client_success(self, mock_bootstrap): + """Test successful client settings creation""" + default_cfg = make_default_settings() + settings_list = [default_cfg] + mock_bootstrap.SETTINGS_OBJECTS = settings_list + + result = settings.create_client("new_client") + + assert result.client == "new_client" + # Verify ll_model settings are copied from default + result_ll_model = result.model_dump()["ll_model"] + default_ll_model = default_cfg.model_dump()["ll_model"] + assert result_ll_model["max_tokens"] == default_ll_model["max_tokens"] + assert len(settings_list) == 2 + assert settings_list[-1].client == "new_client" + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_create_client_already_exists(self, mock_settings_objects): + """Test creating client settings when client already exists""" + test_cfg = make_test_client_settings() + mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) + + with pytest.raises(ValueError, match="client test_client already exists"): + settings.create_client("test_client") + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_found(self, mock_settings_objects): + """Test getting client settings when client exists""" + test_cfg = make_test_client_settings() + mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) + + result = settings.get_client("test_client") + + assert result == test_cfg + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_not_found(self, mock_settings_objects): + """Test getting client settings when client doesn't exist""" + default_cfg = make_default_settings() + mock_settings_objects.__iter__ = MagicMock(return_value=iter([default_cfg])) + + with pytest.raises(ValueError, match="client nonexistent not found"): + settings.get_client("nonexistent") + + @patch("server.api.utils.settings.bootstrap.SETTINGS_OBJECTS") + @patch("server.api.utils.settings.get_client") + def test_update_client(self, mock_get_settings, mock_settings_objects): + """Test updating 
client settings""" + test_cfg = make_test_client_settings() + mock_get_settings.return_value = test_cfg + mock_settings_objects.remove = MagicMock() + mock_settings_objects.append = MagicMock() + mock_settings_objects.__iter__ = MagicMock(return_value=iter([test_cfg])) + + new_settings = Settings(client="test_client", max_tokens=800, temperature=0.9) + result = settings.update_client(new_settings, "test_client") + + assert result.client == "test_client" + mock_settings_objects.remove.assert_called_once_with(test_cfg) + mock_settings_objects.append.assert_called_once() ##################################################### @@ -58,8 +113,39 @@ class TestClientSettings: class TestServerConfiguration: """Test server configuration operations""" - # test_get_server: See test/unit/server/api/utils/test_utils_settings.py::TestGetServer::test_get_server_returns_config - # test_update_server: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateServer::test_update_server_updates_databases + @pytest.mark.asyncio + @patch("server.api.utils.settings.get_mcp_prompts_with_overrides") + @patch("server.api.utils.settings.bootstrap.DATABASE_OBJECTS") + @patch("server.api.utils.settings.bootstrap.MODEL_OBJECTS") + @patch("server.api.utils.settings.bootstrap.OCI_OBJECTS") + async def test_get_server(self, mock_oci, mock_models, mock_databases, mock_get_prompts): + """Test getting server configuration""" + mock_databases.__iter__ = MagicMock( + return_value=iter([Database(name="test", user="u", password="p", dsn="d")]) + ) + mock_models.__iter__ = MagicMock(return_value=iter([Model(id="test", provider="openai", type="ll")])) + mock_oci.__iter__ = MagicMock(return_value=iter([OracleCloudSettings(auth_profile="DEFAULT")])) + mock_get_prompts.return_value = [] + + mock_mcp_engine = MagicMock() + result = await settings.get_server(mock_mcp_engine) + + assert "database_configs" in result + assert "model_configs" in result + assert "oci_configs" in result + assert "prompt_configs" in result + + @patch("server.api.utils.settings.bootstrap") + def test_update_server(self, mock_bootstrap): + """Test updating server configuration""" + mock_bootstrap.DATABASE_OBJECTS = [] + mock_bootstrap.MODEL_OBJECTS = [] + mock_bootstrap.OCI_OBJECTS = [] + + settings.update_server(make_sample_config_data()) + + assert hasattr(mock_bootstrap, "DATABASE_OBJECTS") + assert hasattr(mock_bootstrap, "MODEL_OBJECTS") @patch("server.api.utils.settings.bootstrap") def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap): @@ -99,14 +185,64 @@ def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap): class TestConfigLoading: """Test configuration loading operations""" - # test_load_config_from_json_data_with_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_with_client - # test_load_config_from_json_data_without_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_without_client - # test_load_config_from_json_data_missing_client_settings: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_raises_missing_settings - # test_read_config_from_json_file_success: See test/unit/server/api/utils/test_utils_settings.py::TestReadConfigFromJsonFile::test_read_config_from_json_file_success - # test_read_config_from_json_file_not_exists: Empty test stub - not implemented - # 
test_read_config_from_json_file_wrong_extension: Empty test stub - not implemented - # test_logger_exists: See test/unit/server/api/utils/test_utils_settings.py::TestLoggerConfiguration::test_logger_exists - pass + @patch("server.api.utils.settings.update_server") + @patch("server.api.utils.settings.update_client") + def test_load_config_from_json_data_with_client(self, mock_update_client, mock_update_server): + """Test loading config from JSON data with specific client""" + config_data = make_sample_config_data() + settings.load_config_from_json_data(config_data, client="test_client") + + mock_update_server.assert_called_once_with(config_data) + mock_update_client.assert_called_once() + + @patch("server.api.utils.settings.update_server") + @patch("server.api.utils.settings.update_client") + def test_load_config_from_json_data_without_client(self, mock_update_client, mock_update_server): + """Test loading config from JSON data without specific client""" + config_data = make_sample_config_data() + settings.load_config_from_json_data(config_data) + + mock_update_server.assert_called_once_with(config_data) + assert mock_update_client.call_count == 2 + + @patch("server.api.utils.settings.update_server") + def test_load_config_from_json_data_missing_client_settings(self, _mock_update_server): + """Test loading config from JSON data without client_settings""" + invalid_config = {"database_configs": [], "model_configs": [], "oci_configs": [], "prompt_configs": []} + + with pytest.raises(KeyError, match="Missing client_settings in config file"): + settings.load_config_from_json_data(invalid_config) + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.json"}) + @patch("os.path.isfile") + @patch("os.access") + @patch("builtins.open", mock_open(read_data='{"test": "data"}')) + @patch("json.load") + def test_read_config_from_json_file_success(self, mock_json_load, mock_access, mock_isfile): + """Test successful reading of config file""" + mock_isfile.return_value = True + mock_access.return_value = True + mock_json_load.return_value = make_sample_config_data() + + result = settings.read_config_from_json_file() + + assert isinstance(result, Configuration) + mock_json_load.assert_called_once() + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/nonexistent.json"}) + @patch("os.path.isfile") + def test_read_config_from_json_file_not_exists(self, mock_isfile): + """Test reading config file that doesn't exist""" + mock_isfile.return_value = False + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.txt"}) + def test_read_config_from_json_file_wrong_extension(self): + """Test reading config file with wrong extension""" + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(settings, "logger") + assert settings.logger.name == "api.core.settings" ##################################################### @@ -115,8 +251,25 @@ class TestConfigLoading: class TestPromptOverrides: """Test prompt override operations""" - # test_load_prompt_override_with_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_with_text - # test_load_prompt_override_without_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_without_text + @patch("server.api.utils.settings.cache") + def test_load_prompt_override_with_text(self, mock_cache): + """Test loading prompt override when text is provided""" + prompt = {"name": "optimizer_test-prompt", "text": "You are a test 
assistant"}
+
+        result = settings._load_prompt_override(prompt)
+
+        assert result is True
+        mock_cache.set_override.assert_called_once_with("optimizer_test-prompt", "You are a test assistant")
+
+    @patch("server.api.utils.settings.cache")
+    def test_load_prompt_override_without_text(self, mock_cache):
+        """Test loading prompt override when text is not provided"""
+        prompt = {"name": "optimizer_test-prompt"}
+
+        result = settings._load_prompt_override(prompt)
+
+        assert result is False
+        mock_cache.set_override.assert_not_called()
 
     @patch("server.api.utils.settings.cache")
     def test_load_prompt_override_empty_text(self, mock_cache):
@@ -128,6 +281,36 @@ def test_load_prompt_override_empty_text(self, mock_cache):
         assert result is False
         mock_cache.set_override.assert_not_called()
 
-    # test_load_prompt_configs_success: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_with_prompts
-    # test_load_prompt_configs_no_prompts_key: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_without_key
-    # test_load_prompt_configs_empty_list: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_empty_list
+    @patch("server.api.utils.settings._load_prompt_override")
+    def test_load_prompt_configs_success(self, mock_load_override):
+        """Test loading prompt configs successfully"""
+        mock_load_override.side_effect = [True, True, False]
+        config_data = {
+            "prompt_configs": [
+                {"name": "prompt1", "text": "text1"},
+                {"name": "prompt2", "text": "text2"},
+                {"name": "prompt3", "text": "text3"},
+            ]
+        }
+
+        settings._load_prompt_configs(config_data)
+
+        assert mock_load_override.call_count == 3
+
+    @patch("server.api.utils.settings._load_prompt_override")
+    def test_load_prompt_configs_no_prompts_key(self, mock_load_override):
+        """Test loading prompt configs when key is missing"""
+        config_data = {"other_configs": []}
+
+        settings._load_prompt_configs(config_data)
+
+        mock_load_override.assert_not_called()
+
+    @patch("server.api.utils.settings._load_prompt_override")
+    def test_load_prompt_configs_empty_list(self, mock_load_override):
+        """Test loading prompt configs with empty list"""
+        config_data = {"prompt_configs": []}
+
+        settings._load_prompt_configs(config_data)
+
+        mock_load_override.assert_not_called()
diff --git a/tests/server/unit/api/utils/test_utils_testbed.py b/tests/server/unit/api/utils/test_utils_testbed.py
index 7137d4a3..f99dbbdc 100644
--- a/tests/server/unit/api/utils/test_utils_testbed.py
+++ b/tests/server/unit/api/utils/test_utils_testbed.py
@@ -17,15 +17,10 @@
 class TestTestbedUtils:
     """Test testbed utility functions"""
 
-    @pytest.fixture
-    def mock_connection(self):
-        """Mock database connection fixture"""
-        return MagicMock(spec=Connection)
-
-    @pytest.fixture
-    def sample_qa_data(self):
-        """Sample QA data fixture"""
-        return {
+    def setup_method(self):
+        """Set up test data before each test; pytest cannot collect classes that define __init__"""
+        self.mock_connection = MagicMock(spec=Connection)
+        self.sample_qa_data = {
             "question": "What is the capital of France?",
             "answer": "Paris",
             "context": "France is a country in Europe.",
@@ -78,11 +73,11 @@ def test_jsonl_to_json_content_whitespace_content(self):
         testbed.jsonl_to_json_content(content)
 
     @patch("server.api.utils.databases.execute_sql")
-    def test_create_testset_objects(self, mock_execute_sql, mock_connection):
+    def test_create_testset_objects(self, mock_execute_sql):
         """Test creating testset database objects"""
         mock_execute_sql.return_value = []
 
-        
testbed.create_testset_objects(mock_connection) + testbed.create_testset_objects(self.mock_connection) # Should execute 3 SQL statements (testsets, testset_qa, evaluations tables) assert mock_execute_sql.call_count == 3 diff --git a/tests/server/unit/bootstrap/test_bootstrap.py b/tests/server/unit/bootstrap/test_bootstrap.py index 9f28e5ed..9caedd01 100644 --- a/tests/server/unit/bootstrap/test_bootstrap.py +++ b/tests/server/unit/bootstrap/test_bootstrap.py @@ -3,20 +3,48 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ # spell-checker: disable +# pylint: disable=protected-access import-error import-outside-toplevel -# ============================================================================= -# DEPRECATED: Tests in this file have been replaced by more comprehensive tests -# in test/unit/server/bootstrap/test_bootstrap_bootstrap.py -# ============================================================================= -# -# test_module_imports_and_initialization -> Replaced by: -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_calls_all_bootstrap_functions -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_database_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_model_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_oci_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_settings_objects_is_list -# -# test_logger_exists -> Replaced by: -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_exists -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_name -# +import importlib +from unittest.mock import patch, MagicMock + +from server.bootstrap import bootstrap + + +class TestBootstrap: + """Test bootstrap module functionality""" + + @patch("server.bootstrap.databases.main") + @patch("server.bootstrap.models.main") + @patch("server.bootstrap.oci.main") + @patch("server.bootstrap.settings.main") + def test_module_imports_and_initialization( + self, mock_settings, mock_oci, mock_models, mock_databases + ): + """Test that all bootstrap objects are properly initialized""" + # Mock return values + mock_databases.return_value = [MagicMock()] + mock_models.return_value = [MagicMock()] + mock_oci.return_value = [MagicMock()] + mock_settings.return_value = [MagicMock()] + + # Reload the module to trigger initialization + + importlib.reload(bootstrap) + + # Verify all bootstrap functions were called + mock_databases.assert_called_once() + mock_models.assert_called_once() + mock_oci.assert_called_once() + mock_settings.assert_called_once() + + # Verify objects are created + assert hasattr(bootstrap, "DATABASE_OBJECTS") + assert hasattr(bootstrap, "MODEL_OBJECTS") + assert hasattr(bootstrap, "OCI_OBJECTS") + assert hasattr(bootstrap, "SETTINGS_OBJECTS") + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(bootstrap, "logger") + assert bootstrap.logger.name == "bootstrap" From 18d51fe28f02eb4b375c1ee30ee640ecef7c505e Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 21:05:35 +0000 Subject: [PATCH 17/20] Merge --- src/client/content/config/tabs/mcp.py | 6 ++-- src/client/content/tools/tabs/split_embed.py | 38 ++++++++++++++++---- 2 files changed, 34 insertions(+), 10 
deletions(-) diff --git a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py index 32760515..484de2b9 100644 --- a/src/client/content/config/tabs/mcp.py +++ b/src/client/content/config/tabs/mcp.py @@ -123,19 +123,17 @@ def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None: col1.markdown("Name", unsafe_allow_html=True) col2.markdown("​") for mcp_name in configs: - # The key prefix is to give each widget a unique key in the loop; the key itself is never used - key_prefix = f"{mcp_server}_{mcp_type}_{mcp_name}" col1.text_input( "Name", value=mcp_name, label_visibility="collapsed", disabled=True, - key=f"{key_prefix}_name", + key=f"{mcp_server}_{mcp_type}_{mcp_name}_input", ) col2.button( "Details", on_click=mcp_details, - key=f"{key_prefix}_details", + key=f"{mcp_server}_{mcp_type}_{mcp_name}_details", kwargs={"mcp_server": mcp_server, "mcp_type": mcp_type, "mcp_name": mcp_name}, ) diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index ff44f0ee..b7e3f6b5 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -123,23 +123,49 @@ def files_data_editor(files, key): def update_chunk_overlap_slider() -> None: - """Keep text and slider input aligned""" - state.selected_chunk_overlap_slider = state.selected_chunk_overlap_input + """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" + new_overlap = state.selected_chunk_overlap_input + # Ensure overlap doesn't exceed chunk size + if hasattr(state, "selected_chunk_size_slider"): + chunk_size = state.selected_chunk_size_slider + if new_overlap >= chunk_size: + new_overlap = max(0, chunk_size - 1) + state.selected_chunk_overlap_input = new_overlap + state.selected_chunk_overlap_slider = new_overlap def update_chunk_overlap_input() -> None: - """Keep text and slider input aligned""" - state.selected_chunk_overlap_input = state.selected_chunk_overlap_slider + """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" + new_overlap = state.selected_chunk_overlap_slider + # Ensure overlap doesn't exceed chunk size + if hasattr(state, "selected_chunk_size_slider"): + chunk_size = state.selected_chunk_size_slider + if new_overlap >= chunk_size: + new_overlap = max(0, chunk_size - 1) + state.selected_chunk_overlap_slider = new_overlap + state.selected_chunk_overlap_input = new_overlap def update_chunk_size_slider() -> None: - """Keep text and slider input aligned""" + """Keep text and slider input aligned and adjust overlap if needed""" state.selected_chunk_size_slider = state.selected_chunk_size_input + # If overlap exceeds new chunk size, cap it + if hasattr(state, "selected_chunk_overlap_slider"): + if state.selected_chunk_overlap_slider >= state.selected_chunk_size_slider: + new_overlap = max(0, state.selected_chunk_size_slider - 1) + state.selected_chunk_overlap_slider = new_overlap + state.selected_chunk_overlap_input = new_overlap def update_chunk_size_input() -> None: - """Keep text and slider input aligned""" + """Keep text and slider input aligned and adjust overlap if needed""" state.selected_chunk_size_input = state.selected_chunk_size_slider + # If overlap exceeds new chunk size, cap it + if hasattr(state, "selected_chunk_overlap_input"): + if state.selected_chunk_overlap_input >= state.selected_chunk_size_input: + new_overlap = max(0, state.selected_chunk_size_input - 1) + state.selected_chunk_overlap_input = new_overlap + 
state.selected_chunk_overlap_slider = new_overlap ############################################################################# From 8559128de165dfc743245eaf2f8cd90ebfec339c Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 21:32:15 +0000 Subject: [PATCH 18/20] Update tests --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index e3218d31..0484bbf7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -52,7 +52,7 @@ ignore=CVS,.venv # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. -ignore-paths=.*[/\\]wip[/\\].*,src/client/mcp,docs/themes/relearn,docs/public,docs/static/demoware +ignore-paths=.*[/\\]wip[/\\].*,src/client/mcp,docs/themes/relearn,docs/public,docs/static/demoware,src/server/agents/chatbot.py # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores From 625482d8fb54d48bad8a6c1e308369fe485c04dc Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 21:32:27 +0000 Subject: [PATCH 19/20] Update Tests --- .../content/tools/tabs/test_split_embed.py | 9 +- .../integration/utils/test_st_common.py | 336 +----------------- .../client/unit/content/test_chatbot_unit.py | 43 ++- .../client/unit/utils/test_st_common_unit.py | 291 +-------------- .../integration/test_endpoints_settings.py | 5 +- 5 files changed, 60 insertions(+), 624 deletions(-) diff --git a/tests/client/integration/content/tools/tabs/test_split_embed.py b/tests/client/integration/content/tools/tabs/test_split_embed.py index 7fdd7b47..8bd9cb12 100644 --- a/tests/client/integration/content/tools/tabs/test_split_embed.py +++ b/tests/client/integration/content/tools/tabs/test_split_embed.py @@ -622,17 +622,18 @@ def test_create_new_vs_toggle_shown_when_vector_stores_exist(self, app_server, a # Ensure database has vector stores if at.session_state.database_configs: # Find matching model ID for the vector store - model_id = None + # Model format in vector stores must be "provider/model_id" to match enabled_models_lookup keys + model_key = None for model in at.session_state.model_configs: if model["type"] == "embed" and model.get("enabled"): - model_id = model["id"] + model_key = f"{model.get('provider')}/{model['id']}" break - if model_id: + if model_key: at.session_state.database_configs[0]["vector_stores"] = [ { "alias": "existing_vs", - "model": model_id, + "model": model_key, "vector_store": "VECTOR_STORE_TABLE", "chunk_size": 500, "chunk_overlap": 50, diff --git a/tests/client/integration/utils/test_st_common.py b/tests/client/integration/utils/test_st_common.py index 164ecb13..eae17b05 100644 --- a/tests/client/integration/utils/test_st_common.py +++ b/tests/client/integration/utils/test_st_common.py @@ -2,329 +2,19 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Note: Vector store selection tests have been moved to test_vs_options.py +following the refactor that moved vector store functionality from st_common.py +to vs_options.py. 
""" # spell-checker: disable -from unittest.mock import patch - -import pandas as pd -import pytest -import streamlit as st -from streamlit import session_state as state - -from client.utils import st_common - - -############################################################################# -# Fixtures -############################################################################# -@pytest.fixture -def vector_store_state(sample_vector_store_data): - """Setup common vector store state for tests using shared test data""" - # Setup initial state with vector search settings - state.client_settings = { - "vector_search": { - "enabled": True, - **sample_vector_store_data, - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - "ll_model": {"model": "gpt-4", "temperature": 0.8}, - } - - # Set widget states to simulate user selections - state.selected_vector_search_model = sample_vector_store_data["model"] - state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] - state.selected_vector_search_chunk_overlap = sample_vector_store_data["chunk_overlap"] - state.selected_vector_search_distance_metric = sample_vector_store_data["distance_metric"] - state.selected_vector_search_alias = sample_vector_store_data["alias"] - state.selected_vector_search_index_type = sample_vector_store_data["index_type"] - - yield state - - # Cleanup after test - for key in list(state.keys()): - if key.startswith("selected_vector_search_"): - del state[key] - - -############################################################################# -# Test Vector Store Reset Button Functionality - Integration Tests -############################################################################# -class TestVectorStoreResetButtonIntegration: - """Integration tests for vector store selection Reset button""" - - def test_reset_button_callback_execution(self, app_server, vector_store_state, sample_vector_store_data): - """Test that the Reset button callback is properly executed when clicked""" - assert app_server is not None - assert vector_store_state is not None - - reset_callback_executed = False - - def mock_button(label, **kwargs): - nonlocal reset_callback_executed - if "Reset" in label and "on_click" in kwargs: - # Execute the callback to simulate button click - kwargs["on_click"]() - reset_callback_executed = True - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox"), - patch.object(st, "info"), - ): - # Create test dataframe using shared test data - vs_df = pd.DataFrame([sample_vector_store_data]) - - # Mock enabled models - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - - # Call the function - st_common.render_vector_store_selection(vs_df) - - # Verify reset callback was executed - assert reset_callback_executed - - # Verify all widget states are cleared - assert state.selected_vector_search_model == "" - assert state.selected_vector_search_chunk_size == "" - assert state.selected_vector_search_chunk_overlap == "" - assert state.selected_vector_search_distance_metric == "" - assert state.selected_vector_search_alias == "" - assert state.selected_vector_search_index_type == "" - - # Verify client_settings are also cleared - assert state.client_settings["vector_search"]["model"] == "" - assert 
state.client_settings["vector_search"]["chunk_size"] == "" - assert state.client_settings["vector_search"]["chunk_overlap"] == "" - assert state.client_settings["vector_search"]["distance_metric"] == "" - assert state.client_settings["vector_search"]["vector_store"] == "" - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["index_type"] == "" - - def test_reset_preserves_non_vector_store_settings(self, app_server, vector_store_state, sample_vector_store_data): - """Test that Reset only affects vector store fields, not other settings""" - assert app_server is not None - assert vector_store_state is not None - - def mock_button(label, **kwargs): - if "Reset" in label and "on_click" in kwargs: - kwargs["on_click"]() - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox"), - patch.object(st, "info"), - ): - vs_df = pd.DataFrame([sample_vector_store_data]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # Vector store fields should be cleared - assert state.client_settings["vector_search"]["model"] == "" - assert state.client_settings["vector_search"]["alias"] == "" - - # Other settings should be preserved - assert state.client_settings["vector_search"]["top_k"] == 10 - assert state.client_settings["vector_search"]["search_type"] == "Similarity" - assert state.client_settings["vector_search"]["score_threshold"] == 0.5 - assert state.client_settings["database"]["alias"] == "DEFAULT" - assert state.client_settings["ll_model"]["model"] == "gpt-4" - assert state.client_settings["ll_model"]["temperature"] == 0.8 - - def test_auto_population_after_reset_single_option(self, app_server, sample_vector_store_data): - """Test that fields with single options are auto-populated after reset""" - assert app_server is not None - - # Setup clean state - state.client_settings = { - "vector_search": { - "enabled": True, - "model": "", # Empty after reset - "chunk_size": "", - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "", - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - # Clear widget states (simulating post-reset state) - state.selected_vector_search_model = "" - state.selected_vector_search_chunk_size = "" - state.selected_vector_search_chunk_overlap = "" - state.selected_vector_search_distance_metric = "" - state.selected_vector_search_alias = "" - state.selected_vector_search_index_type = "" - - selectbox_calls = [] - - def mock_selectbox(label, options, key, index, disabled=False): - selectbox_calls.append( - {"label": label, "options": options, "key": key, "index": index, "disabled": disabled} - ) - # Return the value at index - return options[index] if 0 <= index < len(options) else "" - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button"), - patch.object(st.sidebar, "selectbox", side_effect=mock_selectbox), - patch.object(st, "info"), - ): - # Create dataframe with single option per field using shared fixture - single_vs = sample_vector_store_data.copy() - single_vs["alias"] = "single_alias" - single_vs["vector_store"] = "single_vs" - vs_df = pd.DataFrame([single_vs]) - - with 
patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # Verify auto-population happened for single options - assert state.client_settings["vector_search"]["alias"] == "single_alias" - assert state.client_settings["vector_search"]["model"] == sample_vector_store_data["model"] - assert state.client_settings["vector_search"]["chunk_size"] == sample_vector_store_data["chunk_size"] - assert state.client_settings["vector_search"]["chunk_overlap"] == sample_vector_store_data["chunk_overlap"] - assert ( - state.client_settings["vector_search"]["distance_metric"] - == sample_vector_store_data["distance_metric"] - ) - assert state.client_settings["vector_search"]["index_type"] == sample_vector_store_data["index_type"] - - # Verify widget states were also set - assert state.selected_vector_search_alias == "single_alias" - assert state.selected_vector_search_model == sample_vector_store_data["model"] - - def test_no_auto_population_with_multiple_options( - self, app_server, sample_vector_store_data, sample_vector_store_data_alt - ): - """Test that fields with multiple options are NOT auto-populated after reset""" - assert app_server is not None - - # Setup clean state after reset - state.client_settings = { - "vector_search": { - "enabled": True, - "model": "", - "chunk_size": "", - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "", - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - "fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - # Clear widget states - for key in ["model", "chunk_size", "chunk_overlap", "distance_metric", "alias", "index_type"]: - state[f"selected_vector_search_{key}"] = "" - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button"), - patch.object(st.sidebar, "selectbox", return_value=""), - patch.object(st, "info"), - ): - # Create dataframe with multiple options using shared fixtures - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "alias1" - vs2 = sample_vector_store_data_alt.copy() - vs2["alias"] = "alias2" - vs_df = pd.DataFrame([vs1, vs2]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # With multiple options, fields should remain empty (no auto-population) - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["chunk_size"] == "" - assert state.client_settings["vector_search"]["chunk_overlap"] == "" - assert state.client_settings["vector_search"]["distance_metric"] == "" - assert state.client_settings["vector_search"]["index_type"] == "" - - def test_reset_button_with_filtered_dataframe( - self, app_server, sample_vector_store_data, sample_vector_store_data_alt - ): - """Test reset button behavior with dynamically filtered dataframes""" - assert app_server is not None - - # Setup state with a filter already applied - state.client_settings = { - "vector_search": { - "enabled": True, - "model": sample_vector_store_data["model"], - "chunk_size": sample_vector_store_data["chunk_size"], - "chunk_overlap": "", - "distance_metric": "", - "vector_store": "", - "alias": "alias1", # Filter applied - "index_type": "", - "top_k": 10, - "search_type": "Similarity", - "score_threshold": 0.5, - 
"fetch_k": 20, - "lambda_mult": 0.5, - }, - "database": {"alias": "DEFAULT"}, - } - - state.selected_vector_search_alias = "alias1" - state.selected_vector_search_model = sample_vector_store_data["model"] - state.selected_vector_search_chunk_size = sample_vector_store_data["chunk_size"] - - def mock_button(label, **kwargs): - if "Reset" in label and "on_click" in kwargs: - kwargs["on_click"]() - return True - - with ( - patch.object(st.sidebar, "subheader"), - patch.object(st.sidebar, "button", side_effect=mock_button), - patch.object(st.sidebar, "selectbox", return_value=""), - patch.object(st, "info"), - ): - # Create dataframe with same alias using shared fixtures - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "alias1" - vs2 = sample_vector_store_data_alt.copy() - vs2["alias"] = "alias1" - vs_df = pd.DataFrame([vs1, vs2]) - - with patch.object(st_common, "enabled_models_lookup") as mock_models: - mock_models.return_value = {"openai/text-embed-3": {"id": "text-embed-3"}} - st_common.render_vector_store_selection(vs_df) - - # After reset, all filters should be cleared - assert state.selected_vector_search_alias == "" - assert state.selected_vector_search_model == "" - assert state.selected_vector_search_chunk_size == "" - assert state.client_settings["vector_search"]["alias"] == "" - assert state.client_settings["vector_search"]["model"] == "" - assert state.client_settings["vector_search"]["chunk_size"] == "" +# This file previously contained integration tests for vector store selection +# functionality that was part of st_common.py. Those tests have been moved to: +# tests/client/integration/utils/test_vs_options.py +# +# The st_common.py module no longer contains vector store selection functions. +# See vs_options.py for: +# - vector_search_sidebar() +# - vector_store_selection() +# - Related helper functions (_get_vs_fields, _reset_selections, etc.) 
diff --git a/tests/client/unit/content/test_chatbot_unit.py b/tests/client/unit/content/test_chatbot_unit.py index 66b309ed..a01b04aa 100644 --- a/tests/client/unit/content/test_chatbot_unit.py +++ b/tests/client/unit/content/test_chatbot_unit.py @@ -34,14 +34,18 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): mock_columns = MagicMock(return_value=[mock_col, mock_col, mock_col]) mock_subheader = MagicMock() + mock_expander = MagicMock() + mock_expander.__enter__ = MagicMock(return_value=mock_expander) + mock_expander.__exit__ = MagicMock(return_value=False) monkeypatch.setattr(st, "markdown", mock_markdown) monkeypatch.setattr(st, "columns", mock_columns) monkeypatch.setattr(st, "subheader", mock_subheader) + monkeypatch.setattr(st, "expander", MagicMock(return_value=mock_expander)) - # Create test context - context = [ - [ + # Create test context - now expects dict with "documents" key + context = { + "documents": [ { "page_content": "This is chunk 1 content", "metadata": {"filename": "doc1.pdf", "source": "/path/to/doc1.pdf", "page": 1}, @@ -55,8 +59,8 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): "metadata": {"filename": "doc1.pdf", "source": "/path/to/doc1.pdf", "page": 3}, }, ], - "test query", - ] + "context_input": "test query", + } # Call function chatbot.show_vector_search_refs(context) @@ -64,9 +68,6 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): # Verify References header was shown assert any("References" in str(call) for call in mock_markdown.call_args_list) - # Verify Notes with query shown - assert any("test query" in str(call) for call in mock_markdown.call_args_list) - def test_show_vector_search_refs_missing_metadata(self, monkeypatch): """Test showing vector search references when metadata is missing""" from client.content import chatbot @@ -83,21 +84,25 @@ def test_show_vector_search_refs_missing_metadata(self, monkeypatch): mock_columns = MagicMock(return_value=[mock_col]) mock_subheader = MagicMock() + mock_expander = MagicMock() + mock_expander.__enter__ = MagicMock(return_value=mock_expander) + mock_expander.__exit__ = MagicMock(return_value=False) monkeypatch.setattr(st, "markdown", mock_markdown) monkeypatch.setattr(st, "columns", mock_columns) monkeypatch.setattr(st, "subheader", mock_subheader) + monkeypatch.setattr(st, "expander", MagicMock(return_value=mock_expander)) - # Create test context with missing metadata - context = [ - [ + # Create test context with missing metadata - now expects dict with "documents" key + context = { + "documents": [ { "page_content": "Content without metadata", "metadata": {}, # Empty metadata - will cause KeyError } ], - "test query", - ] + "context_input": "test query", + } # Call function - should handle KeyError gracefully chatbot.show_vector_search_refs(context) @@ -138,7 +143,7 @@ def test_setup_sidebar_no_models(self, monkeypatch): def test_setup_sidebar_with_models(self, monkeypatch): """Test setup_sidebar with enabled language models""" from client.content import chatbot - from client.utils import st_common + from client.utils import st_common, vs_options from streamlit import session_state as state # Mock enabled_models_lookup to return models @@ -148,7 +153,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): monkeypatch.setattr(st_common, "tools_sidebar", MagicMock()) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(st_common, "vector_search_sidebar", 
MagicMock()) + monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) # Initialize state state.enable_client = True @@ -162,7 +167,7 @@ def test_setup_sidebar_with_models(self, monkeypatch): def test_setup_sidebar_client_disabled(self, monkeypatch): """Test setup_sidebar when client gets disabled""" from client.content import chatbot - from client.utils import st_common + from client.utils import st_common, vs_options from streamlit import session_state as state import streamlit as st @@ -175,7 +180,7 @@ def disable_client(): monkeypatch.setattr(st_common, "tools_sidebar", disable_client) monkeypatch.setattr(st_common, "history_sidebar", MagicMock()) monkeypatch.setattr(st_common, "ll_sidebar", MagicMock()) - monkeypatch.setattr(st_common, "vector_search_sidebar", MagicMock()) + monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock()) # Mock st.stop mock_stop = MagicMock(side_effect=SystemExit) @@ -308,10 +313,10 @@ def test_display_chat_history_with_vector_search(self, monkeypatch): mock_show_refs = MagicMock() monkeypatch.setattr(chatbot, "show_vector_search_refs", mock_show_refs) - # Create history with tool message + # Create history with tool message (tool name changed to optimizer_vs-retriever) vector_refs = [[{"page_content": "content", "metadata": {}}], "query"] history = [ - {"role": "tool", "name": "oraclevs_tool", "content": json.dumps(vector_refs)}, + {"role": "tool", "name": "optimizer_vs-retriever", "content": json.dumps(vector_refs)}, {"role": "ai", "content": "Based on the documents..."}, ] diff --git a/tests/client/unit/utils/test_st_common_unit.py b/tests/client/unit/utils/test_st_common_unit.py index 1884dc24..74791baf 100644 --- a/tests/client/unit/utils/test_st_common_unit.py +++ b/tests/client/unit/utils/test_st_common_unit.py @@ -2,14 +2,16 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Note: Vector store helper tests have been moved to test_vs_options_unit.py +following the refactor that moved vector store functionality from st_common.py +to vs_options.py. 
""" # spell-checker: disable from io import BytesIO -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock -import pandas as pd -import streamlit as st from streamlit import session_state as state from client.utils import api_call, st_common @@ -396,275 +398,14 @@ def test_is_db_configured_false_different_alias(self, app_server): assert result is False -############################################################################# -# Test Vector Store Helpers -############################################################################# -class TestVectorStoreHelpers: - """Test vector store helper functions""" - - def test_update_filtered_vector_store_no_filters(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with no filters""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - - vs_df = pd.DataFrame(sample_vector_stores_list) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should return all rows (filtered by enabled models only) - assert len(result) == 2 - - def test_update_filtered_vector_store_with_alias_filter(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with alias filter""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - state.selected_vector_search_alias = "vs1" - - vs_df = pd.DataFrame(sample_vector_stores_list) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should only return vs1 - assert len(result) == 1 - assert result.iloc[0]["alias"] == "vs1" - - def test_update_filtered_vector_store_disabled_model(self, app_server, sample_vector_store_data): - """Test that disabled embedding models filter out vector stores""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": False}, - ] - - # Use shared fixture with vs1 alias - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "vs1" - vs1.pop("vector_store", None) - vs_df = pd.DataFrame([vs1]) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should return empty (model not enabled) - assert len(result) == 0 - - def test_update_filtered_vector_store_multiple_filters(self, app_server, sample_vector_stores_list): - """Test update_filtered_vector_store with multiple filters""" - assert app_server is not None - - state.model_configs = [ - {"id": "text-embed-3", "provider": "openai", "type": "embed", "enabled": True}, - ] - state.selected_vector_search_alias = "vs1" - state.selected_vector_search_model = "openai/text-embed-3" - state.selected_vector_search_chunk_size = 1000 - - # Use only vs1 entries from the fixture - vs1_entries = [vs.copy() for vs in sample_vector_stores_list] - for vs in vs1_entries: - vs["alias"] = "vs1" - - vs_df = pd.DataFrame(vs1_entries) - - result = st_common.update_filtered_vector_store(vs_df) - - # Should only return the 1000 chunk_size entry - assert len(result) == 1 - assert result.iloc[0]["chunk_size"] == 1000 - - -############################################################################# -# Test _vs_gen_selectbox Function -############################################################################# -class TestVsGenSelectbox: - """Unit tests for the _vs_gen_selectbox function""" - - def test_single_option_auto_select_when_empty(self, app_server): - """Test auto-selection when there's 
one option and current value is empty""" - assert app_server is not None - - # Setup: empty current value - state.client_settings = {"vector_search": {"alias": ""}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "single_option" - - st_common._vs_gen_selectbox("Select Alias:", ["single_option"], "selected_vector_search_alias") - - # Verify auto-selection occurred - assert state.client_settings["vector_search"]["alias"] == "single_option" - assert state.selected_vector_search_alias == "single_option" - - # Verify selectbox was called with correct index (1 = first real option after empty) - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 1 # Index 1 points to "single_option" in ["", "single_option"] - - def test_single_option_no_auto_select_when_populated(self, app_server): - """Test NO auto-selection when there's one option but value already exists""" - assert app_server is not None - - # Setup: existing value - state.client_settings = {"vector_search": {"alias": "existing_value"}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "existing_value" - - st_common._vs_gen_selectbox("Select Alias:", ["existing_value"], "selected_vector_search_alias") - - # Value should remain unchanged (not overwritten) - assert state.client_settings["vector_search"]["alias"] == "existing_value" - - # Verify selectbox was called with existing value's index - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 1 # existing_value is at index 1 - - def test_multiple_options_no_auto_select(self, app_server): - """Test no auto-selection with multiple options""" - assert app_server is not None - - # Setup: empty value with multiple options - state.client_settings = {"vector_search": {"alias": ""}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox( - "Select Alias:", ["option1", "option2", "option3"], "selected_vector_search_alias" - ) - - # Should remain empty (no auto-selection) - assert state.client_settings["vector_search"]["alias"] == "" - - # Verify selectbox was called with index 0 (empty option) - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 0 # Index 0 is the empty option - - def test_no_valid_options_disabled(self, app_server): - """Test selectbox is disabled when no valid options""" - assert app_server is not None - - state.client_settings = {"vector_search": {"alias": ""}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox( - "Select Alias:", - [], # No options - "selected_vector_search_alias", - ) - - # Verify selectbox was called with disabled=True - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["disabled"] is True - assert call_args[1]["index"] == 0 - - def test_invalid_current_value_reset(self, app_server): - """Test that invalid current value is reset to empty""" - assert app_server is not None - - # Setup: value that's not in the options - state.client_settings = {"vector_search": {"alias": "invalid_option"}} - - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "" - - st_common._vs_gen_selectbox("Select Alias:", ["valid1", "valid2"], "selected_vector_search_alias") - - 
# Invalid value should not cause error, selectbox should show empty - mock_selectbox.assert_called_once() - call_args = mock_selectbox.call_args - assert call_args[1]["index"] == 0 # Reset to empty option - - -############################################################################# -# Test Reset Button Callback Function -############################################################################# -class TestResetButtonCallback: - """Unit tests for the reset button callback within render_vector_store_selection""" - - def test_reset_clears_correct_fields(self, app_server): - """Test reset callback clears only the specified vector store fields""" - assert app_server is not None - - # Setup initial values - state.client_settings = { - "vector_search": { - "model": "openai/text-embed-3", - "chunk_size": 1000, - "chunk_overlap": 200, - "distance_metric": "cosine", - "vector_store": "vs_test", - "alias": "test_alias", - "index_type": "IVF", - "top_k": 10, - "search_type": "Similarity", - } - } - - # Set widget states - state.selected_vector_search_model = "openai/text-embed-3" - state.selected_vector_search_chunk_size = 1000 - state.selected_vector_search_chunk_overlap = 200 - state.selected_vector_search_distance_metric = "cosine" - state.selected_vector_search_alias = "test_alias" - state.selected_vector_search_index_type = "IVF" - - # Define and execute reset logic (simulating the reset callback) - fields_to_reset = [ - "model", - "chunk_size", - "chunk_overlap", - "distance_metric", - "vector_store", - "alias", - "index_type", - ] - for key in fields_to_reset: - widget_key = f"selected_vector_search_{key}" - state[widget_key] = "" - state.client_settings["vector_search"][key] = "" - - # Verify the correct fields were cleared - for field in fields_to_reset: - assert state.client_settings["vector_search"][field] == "" - assert state[f"selected_vector_search_{field}"] == "" - - # Verify other fields were NOT cleared - assert state.client_settings["vector_search"]["top_k"] == 10 - assert state.client_settings["vector_search"]["search_type"] == "Similarity" - - def test_reset_enables_auto_population(self, app_server): - """Test that reset creates conditions for auto-population""" - assert app_server is not None - - # Setup with existing values - state.client_settings = {"vector_search": {"alias": "existing"}} - state.selected_vector_search_alias = "existing" - - # Execute reset logic - state.selected_vector_search_alias = "" - state.client_settings["vector_search"]["alias"] = "" - - # After reset, fields should be empty (ready for auto-population) - assert state.client_settings["vector_search"]["alias"] == "" - assert state.selected_vector_search_alias == "" - - # Now when _vs_gen_selectbox is called with a single option, it should auto-populate - with patch.object(st.sidebar, "selectbox") as mock_selectbox: - mock_selectbox.return_value = "auto_selected" - - st_common._vs_gen_selectbox("Select Alias:", ["auto_selected"], "selected_vector_search_alias") - - # Verify auto-population happened - assert state.client_settings["vector_search"]["alias"] == "auto_selected" - assert state.selected_vector_search_alias == "auto_selected" +# Note: Vector store helper tests (TestVectorStoreHelpers, TestVsGenSelectbox, +# TestResetButtonCallback) have been moved to test_vs_options_unit.py following +# the refactor that moved vector store functionality from st_common.py to vs_options.py. 
+# +# See tests/client/unit/utils/test_vs_options_unit.py for: +# - TestGetVsFields +# - TestGetValidOptions +# - TestAutoSelect +# - TestResetSelections +# - TestGetCurrentSelections +# - TestRenderSelectbox diff --git a/tests/server/integration/test_endpoints_settings.py b/tests/server/integration/test_endpoints_settings.py index 5cfde6c0..933d7a40 100644 --- a/tests/server/integration/test_endpoints_settings.py +++ b/tests/server/integration/test_endpoints_settings.py @@ -105,7 +105,7 @@ def test_settings_update(self, client, auth_headers): updated_settings = Settings( client="default", ll_model=LargeLanguageSettings(model="updated-model", chat_history=False), - vector_search=VectorSearchSettings(enabled=True, grading=False, search_type="Similarity", top_k=5), + vector_search=VectorSearchSettings(grade=False, search_type="Similarity", top_k=5), oci=OciSettings(auth_profile="UPDATED"), ) @@ -125,8 +125,7 @@ def test_settings_update(self, client, auth_headers): # Check that the values were updated assert new_settings["ll_model"]["model"] == "updated-model" assert new_settings["ll_model"]["chat_history"] is False - assert new_settings["vector_search"]["enabled"] is True - assert new_settings["vector_search"]["grading"] is False + assert new_settings["vector_search"]["grade"] is False assert new_settings["vector_search"]["top_k"] == 5 assert new_settings["oci"]["auth_profile"] == "UPDATED" From a32b10b70e30be45e46fc98e9597d3ab47e5c7ca Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Dec 2025 08:14:51 +0000 Subject: [PATCH 20/20] Remove patches, fix bugs --- pyproject.toml | 5 +- src/common/schema.py | 20 +-- src/launch_server.py | 3 - src/server/api/utils/databases.py | 10 +- src/server/api/utils/oci.py | 3 + src/server/api/utils/testbed.py | 20 +-- src/server/api/v1/databases.py | 23 +-- src/server/api/v1/testbed.py | 34 ++--- src/server/patches/__init__.py | 0 src/server/patches/litellm_patch.py | 40 ------ .../patches/litellm_patch_oci_streaming.py | 132 ------------------ src/server/patches/litellm_patch_transform.py | 80 ----------- .../integration/test_endpoints_testbed.py | 2 +- 13 files changed, 62 insertions(+), 310 deletions(-) delete mode 100644 src/server/patches/__init__.py delete mode 100644 src/server/patches/litellm_patch.py delete mode 100644 src/server/patches/litellm_patch_oci_streaming.py delete mode 100644 src/server/patches/litellm_patch_transform.py diff --git a/pyproject.toml b/pyproject.toml index 4ccd07d8..2e4f5c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ server = [ "langchain-aimlapi==0.1.0", "langchain-cohere==0.4.6", "langchain-community==0.3.31", - "langchain-fireworks==0.3.0", + #"langchain-fireworks==0.3.0", "langchain-google-genai==2.1.12", "langchain-ibm==0.3.20", "langchain-mcp-adapters==0.1.13", @@ -43,7 +43,7 @@ server = [ "langchain-openai==0.3.35", "langchain-together==0.3.1", "langgraph==1.0.1", - "litellm==1.80.0", + "litellm==1.80.7", "llama-index==0.14.8", "lxml==6.0.2", "matplotlib==3.10.7", @@ -70,6 +70,7 @@ test = [ "pytest", "pytest-asyncio", "pytest-cov", + "types-jsonschema", "yamllint" ] diff --git a/src/common/schema.py b/src/common/schema.py index 057487a0..8d761c09 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -346,18 +346,18 @@ class ChatRequest(LanguageModelParameters): ##################################################### # Testbed ##################################################### -class TestSets(BaseModel): - """TestSets""" +class QASets(BaseModel): + """QA Sets - Collection of Q&A 
test sets for testbed evaluation""" tid: str = Field(description="Test ID") - name: str = Field(description="Name of TestSet") - created: str = Field(description="Date TestSet Loaded") + name: str = Field(description="Name of QA Set") + created: str = Field(description="Date QA Set Loaded") -class TestSetQA(BaseModel): - """TestSet Q&A""" +class QASetData(BaseModel): + """QA Set Data - Question/Answer pairs for testbed evaluation""" - qa_data: list = Field(description="TestSet Q&A Data") + qa_data: list = Field(description="QA Set Data") class Evaluation(BaseModel): @@ -390,6 +390,6 @@ class EvaluationReport(Evaluation): ModelEnabledType = ModelAccess.__annotations__["enabled"] OCIProfileType = OracleCloudSettings.__annotations__["auth_profile"] OCIResourceOCID = OracleResource.__annotations__["ocid"] -TestSetsIdType = TestSets.__annotations__["tid"] -TestSetsNameType = TestSets.__annotations__["name"] -TestSetDateType = TestSets.__annotations__["created"] +QASetsIdType = QASets.__annotations__["tid"] +QASetsNameType = QASets.__annotations__["name"] +QASetsDateType = QASets.__annotations__["created"] diff --git a/src/launch_server.py b/src/launch_server.py index 5f2c2453..5bb0da31 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -5,9 +5,6 @@ # spell-checker:ignore configfile fastmcp noauth getpid procs litellm giskard ollama # spell-checker:ignore dotenv apiserver laddr -# Patch litellm for Giskard/Ollama issue -import server.patches.litellm_patch # pylint: disable=unused-import, wrong-import-order - # Set OS Environment before importing other modules # Set OS Environment (Don't move their position to reflect on imports) # pylint: disable=wrong-import-position diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 8ef8735a..c46e0faf 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -100,12 +100,14 @@ def _test(config: Database) -> None: except oracledb.DatabaseError: logger.info("Refreshing %s database connection.", config.name) _ = connect(config) - except ValueError as ex: - raise DbException(status_code=400, detail=f"Database: {str(ex)}") from ex + except DbException: + raise except PermissionError as ex: raise DbException(status_code=401, detail=f"Database: {str(ex)}") from ex except ConnectionError as ex: raise DbException(status_code=503, detail=f"Database: {str(ex)}") from ex + except ValueError as ex: + raise DbException(status_code=400, detail=f"Database: {str(ex)}") from ex except Exception as ex: raise DbException(status_code=500, detail=str(ex)) from ex @@ -136,7 +138,7 @@ def connect(config: Database) -> oracledb.Connection: include_fields = set(DatabaseAuth.model_fields.keys()) db_authn = config.model_dump(include=include_fields) if any(not db_authn[key] for key in ("user", "password", "dsn")): - raise ValueError("missing connection details") + raise DbException(status_code=400, detail=f"Database: {config.name} missing connection details.") logger.info("Connecting to Database: %s", config.dsn) # If a wallet password is provided but no wallet location is set @@ -249,7 +251,7 @@ def get_databases( for db in databases: try: db_conn = connect(config=db) - except (ValueError, PermissionError, ConnectionError, LookupError): + except (DbException, PermissionError, ConnectionError, LookupError): continue db.vector_stores = _get_vs(db_conn) db.connected = True diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index ff0b010f..33904025 100644 --- a/src/server/api/utils/oci.py 
+++ b/src/server/api/utils/oci.py @@ -185,6 +185,9 @@ def get_namespace(config: OracleCloudSettings) -> str: client = init_client(client_type, config) config.namespace = client.get_namespace().data logger.info("OCI: Namespace = %s", config.namespace) + except OciException: + # Re-raise OciException from init_client without wrapping + raise except oci.exceptions.InvalidConfig as ex: raise OciException(status_code=400, detail="Invalid Config") from ex except oci.exceptions.ServiceError as ex: diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py index 21a790d3..dc3d252e 100644 --- a/src/server/api/utils/testbed.py +++ b/src/server/api/utils/testbed.py @@ -95,14 +95,14 @@ def get_testsets(db_conn: Connection) -> list: sql = "SELECT tid, name, to_char(created) FROM oai_testsets ORDER BY created" results = utils_databases.execute_sql(db_conn, sql) try: - testsets = [schema.TestSets(tid=tid.hex(), name=name, created=created) for tid, name, created in results] + testsets = [schema.QASets(tid=tid.hex(), name=name, created=created) for tid, name, created in results] except TypeError: create_testset_objects(db_conn) return testsets -def get_testset_qa(db_conn: Connection, tid: schema.TestSetsIdType) -> schema.TestSetQA: +def get_testset_qa(db_conn: Connection, tid: schema.QASetsIdType) -> schema.QASetData: """Get list of TestSet Q&A""" logger.info("Getting TestSet Q&A for TID: %s", tid) binds = {"tid": tid} @@ -110,10 +110,10 @@ def get_testset_qa(db_conn: Connection, tid: schema.TestSetsIdType) -> schema.Te results = utils_databases.execute_sql(db_conn, sql, binds) qa_data = [qa_data[0] for qa_data in results] - return schema.TestSetQA(qa_data=qa_data) + return schema.QASetData(qa_data=qa_data) -def get_evaluations(db_conn: Connection, tid: schema.TestSetsIdType) -> list[schema.Evaluation]: +def get_evaluations(db_conn: Connection, tid: schema.QASetsIdType) -> list[schema.Evaluation]: """Get list of Evaluations for a TID""" logger.info("Getting Evaluations for: %s", tid) evaluations = [] @@ -133,7 +133,7 @@ def get_evaluations(db_conn: Connection, tid: schema.TestSetsIdType) -> list[sch def delete_qa( db_conn: Connection, - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, ) -> None: """Delete Q&A""" binds = {"tid": tid} @@ -144,11 +144,11 @@ def delete_qa( def upsert_qa( db_conn: Connection, - name: schema.TestSetsNameType, - created: schema.TestSetDateType, + name: schema.QASetsNameType, + created: schema.QASetsDateType, json_data: json, - tid: schema.TestSetsIdType = None, -) -> schema.TestSetsIdType: + tid: schema.QASetsIdType = None, +) -> schema.QASetsIdType: """Upsert Q&A""" logger.info("Upsert TestSet: %s - %s", name, created) parsed_data = json.loads(json_data) @@ -270,7 +270,7 @@ def build_knowledge_base( return testset -def process_report(db_conn: Connection, eid: schema.TestSetsIdType) -> schema.EvaluationReport: +def process_report(db_conn: Connection, eid: schema.QASetsIdType) -> schema.EvaluationReport: """Process an evaluate report""" # Main diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index acf33168..583e5f37 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -76,24 +76,25 @@ async def databases_update( db.connected = False try: - payload.config_dir = db.config_dir - payload.wallet_location = db.wallet_location - logger.debug("Testing Payload: %s", payload) - db_conn = utils_databases.connect(payload) - except (ValueError, PermissionError, ConnectionError, LookupError) as ex: + # Create 
a test config with payload values to test connection + # Only update the actual db object after successful connection + test_config = db.model_copy(update=payload.model_dump(exclude_unset=True)) + logger.debug("Testing Database: %s", test_config) + db_conn = utils_databases.connect(test_config) + except utils_databases.DbException as ex: + raise HTTPException(status_code=ex.status_code, detail=ex.detail) from ex + except (PermissionError, ConnectionError, LookupError) as ex: status_code = 500 - if isinstance(ex, ValueError): - status_code = 400 - elif isinstance(ex, PermissionError): + if isinstance(ex, PermissionError): status_code = 401 elif isinstance(ex, LookupError): status_code = 404 elif isinstance(ex, ConnectionError): status_code = 503 - else: - raise raise HTTPException(status_code=status_code, detail=f"Database: {db.name} {ex}.") from ex - for key, value in payload.model_dump().items(): + + # Connection successful - now update the actual db object + for key, value in payload.model_dump(exclude_unset=True).items(): setattr(db, key, value) # Manage Connections; Unset and disconnect other databases diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 6705b149..035f3a0b 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -40,11 +40,11 @@ @auth.get( "/testsets", description="Get Stored TestSets.", - response_model=list[schema.TestSets], + response_model=list[schema.QASets], ) async def testbed_testsets( client: schema.ClientIdType = Header(default="server"), -) -> list[schema.TestSets]: +) -> list[schema.QASets]: """Get a list of stored TestSets, create TestSet objects if they don't exist""" testsets = utils_testbed.get_testsets(db_conn=utils_databases.get_client_database(client).connection) return testsets @@ -56,7 +56,7 @@ async def testbed_testsets( response_model=list[schema.Evaluation], ) async def testbed_evaluations( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), ) -> list[schema.Evaluation]: """Get Evaluations""" @@ -72,7 +72,7 @@ async def testbed_evaluations( response_model=schema.EvaluationReport, ) async def testbed_evaluation( - eid: schema.TestSetsIdType, + eid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: """Get Evaluations""" @@ -84,13 +84,13 @@ async def testbed_evaluation( @auth.get( "/testset_qa", - description="Get Stored schema.TestSets Q&A.", - response_model=schema.TestSetQA, + description="Get Stored Testbed Q&A.", + response_model=schema.QASetData, ) async def testbed_testset_qa( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Get TestSet Q&A""" return utils_testbed.get_testset_qa( db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper() @@ -102,7 +102,7 @@ async def testbed_testset_qa( description="Delete a TestSet", ) async def testbed_delete_testset( - tid: Optional[schema.TestSetsIdType] = None, + tid: Optional[schema.QASetsIdType] = None, client: schema.ClientIdType = Header(default="server"), ) -> JSONResponse: """Delete TestSet""" @@ -113,14 +113,14 @@ async def testbed_delete_testset( @auth.post( "/testset_load", description="Upsert TestSets.", - response_model=schema.TestSetQA, + response_model=schema.QASetData, ) async def testbed_upsert_testsets( files: list[UploadFile], - name: schema.TestSetsNameType, - tid: 
Optional[schema.TestSetsIdType] = None, + name: schema.QASetsNameType, + tid: Optional[schema.QASetsIdType] = None, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Update stored TestSet data""" created = datetime.now().isoformat() db_conn = utils_databases.get_client_database(client).connection @@ -194,16 +194,16 @@ def _handle_testset_error(ex: Exception, temp_directory, ll_model: str): @auth.post( "/testset_generate", description="Generate Q&A Test Set.", - response_model=schema.TestSetQA, + response_model=schema.QASetData, ) async def testbed_generate_qa( files: list[UploadFile], - name: schema.TestSetsNameType, + name: schema.QASetsNameType, ll_model: str, embed_model: str, questions: int = 2, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Retrieve contents from a local file uploaded and generate Q&A""" # Get the Model Configuration try: @@ -249,7 +249,7 @@ async def _collect_testbed_answers(loaded_testset: QATestset, client: str) -> li response_model=schema.EvaluationReport, ) async def testbed_evaluate( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, judge: str, client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: diff --git a/src/server/patches/__init__.py b/src/server/patches/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/server/patches/litellm_patch.py b/src/server/patches/litellm_patch.py deleted file mode 100644 index d8736e9e..00000000 --- a/src/server/patches/litellm_patch.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -LiteLLM Patch Orchestrator -========================== -This module serves as the entry point for all litellm patches. -It imports and applies patches from specialized modules: - -- litellm_patch_transform: Ollama transform_response patch for non-streaming responses -- litellm_patch_oci_auth: OCI authentication patches (instance principals, request signing) -- litellm_patch_oci_streaming: OCI streaming patches (tool call field fixes) - -All patches use guard checks to prevent double-patching. -""" -# spell-checker:ignore litellm - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch") - -logger.info("Loading litellm patches...") - -# Import patch modules - they apply patches on import -# pylint: disable=unused-import -try: - from . import litellm_patch_transform - - logger.info("✓ Ollama transform_response patch loaded") -except Exception as e: - logger.error("✗ Failed to load Ollama transform patch: %s", e) - -try: - from . import litellm_patch_oci_streaming - - logger.info("✓ OCI streaming patches loaded (handle_generic_stream_chunk)") -except Exception as e: - logger.error("✗ Failed to load OCI streaming patches: %s", e) - -logger.info("All litellm patches loaded successfully") diff --git a/src/server/patches/litellm_patch_oci_streaming.py b/src/server/patches/litellm_patch_oci_streaming.py deleted file mode 100644 index bf3db3eb..00000000 --- a/src/server/patches/litellm_patch_oci_streaming.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
- -OCI Streaming Patches -===================== -Patches for OCI GenAI service streaming responses with tool calls. - -Issue: OCI API returns tool calls without 'arguments' field, causing Pydantic validation error -Error: ValidationError: 1 validation error for OCIStreamChunk message.toolCalls.0.arguments Field required - -This happens when OCI models (e.g., meta.llama-3.1-405b-instruct) attempt tool calling but return -incomplete tool call structures missing the required 'arguments' field during streaming. - -This module patches OCIStreamWrapper._handle_generic_stream_chunk to add missing required fields -with empty defaults before Pydantic validation. -""" -# spell-checker:ignore litellm giskard ollama llms -# pylint: disable=unused-argument,protected-access - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch_oci_streaming") - -# Patch OCI _handle_generic_stream_chunk to add missing 'arguments' field in tool calls -try: - from litellm.llms.oci.chat.transformation import OCIStreamWrapper - - original_handle_generic_stream_chunk = getattr(OCIStreamWrapper, "_handle_generic_stream_chunk", None) -except ImportError: - original_handle_generic_stream_chunk = None - -if original_handle_generic_stream_chunk and not getattr( - original_handle_generic_stream_chunk, "_is_custom_patch", False -): - from litellm.llms.oci.chat.transformation import ( - OCIStreamChunk, - OCITextContentPart, - OCIImageContentPart, - adapt_tools_to_openai_standard, - ) - from litellm.types.utils import ModelResponseStream, StreamingChoices, Delta - - def _fix_missing_tool_call_fields(tool_call: dict) -> list: - """Add missing required fields to tool call and return list of missing fields""" - missing_fields = [] - if "arguments" not in tool_call: - tool_call["arguments"] = "" - missing_fields.append("arguments") - if "id" not in tool_call: - tool_call["id"] = "" - missing_fields.append("id") - if "name" not in tool_call: - tool_call["name"] = "" - missing_fields.append("name") - return missing_fields - - def _patch_tool_calls(dict_chunk: dict) -> None: - """Fix missing required fields in tool calls before Pydantic validation""" - if dict_chunk.get("message") and dict_chunk["message"].get("toolCalls"): - for tool_call in dict_chunk["message"]["toolCalls"]: - missing_fields = _fix_missing_tool_call_fields(tool_call) - if missing_fields: - logger.debug( - "OCI tool call streaming chunk missing fields: %s (Type: %s) - adding empty defaults", - missing_fields, - tool_call.get("type", "unknown"), - ) - - def _extract_text_content(typed_chunk: OCIStreamChunk) -> str: - """Extract text content from chunk message""" - text = "" - if typed_chunk.message and typed_chunk.message.content: - for item in typed_chunk.message.content: - if isinstance(item, OCITextContentPart): - text += item.text - elif isinstance(item, OCIImageContentPart): - raise ValueError("OCI does not support image content in streaming responses") - else: - raise ValueError(f"Unsupported content type in OCI response: {item.type}") - return text - - def custom_handle_generic_stream_chunk(self, dict_chunk: dict): - """ - Custom handler to fix missing 'arguments' field in OCI tool calls. - - OCI API sometimes returns tool calls with structure: - {'type': 'FUNCTION', 'id': '...', 'name': 'tool_name'} - - But OCIStreamChunk Pydantic model requires 'arguments' field in tool calls. - This patch adds an empty arguments dict if missing. 
- """ - # Fix missing required fields in tool calls before Pydantic validation - # OCI streams tool calls progressively, so early chunks may be missing required fields - _patch_tool_calls(dict_chunk) - - # Now proceed with original validation and processing - try: - typed_chunk = OCIStreamChunk(**dict_chunk) - except TypeError as e: - raise ValueError(f"Chunk cannot be casted to OCIStreamChunk: {str(e)}") from e - - if typed_chunk.index is None: - typed_chunk.index = 0 - - text = _extract_text_content(typed_chunk) - - tool_calls = None - if typed_chunk.message and typed_chunk.message.toolCalls: - tool_calls = adapt_tools_to_openai_standard(typed_chunk.message.toolCalls) - - return ModelResponseStream( - choices=[ - StreamingChoices( - index=typed_chunk.index if typed_chunk.index else 0, - delta=Delta( - content=text, - tool_calls=[tool.model_dump() for tool in tool_calls] if tool_calls else None, - provider_specific_fields=None, - thinking_blocks=None, - reasoning_content=None, - ), - finish_reason=typed_chunk.finishReason, - ) - ] - ) - - # Mark it to avoid double patching - custom_handle_generic_stream_chunk._is_custom_patch = True - - # Patch it - OCIStreamWrapper._handle_generic_stream_chunk = custom_handle_generic_stream_chunk diff --git a/src/server/patches/litellm_patch_transform.py b/src/server/patches/litellm_patch_transform.py deleted file mode 100644 index 2bba1f26..00000000 --- a/src/server/patches/litellm_patch_transform.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker:ignore litellm giskard ollama llms -# pylint: disable=unused-argument,protected-access - -from typing import TYPE_CHECKING, List, Any -import time -import litellm -from litellm.llms.ollama.completion.transformation import OllamaConfig -from litellm.types.llms.openai import AllMessageValues -from litellm.types.utils import ModelResponse -from httpx._models import Response - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch_transform") - -# Only patch if not already patched -if not getattr(OllamaConfig.transform_response, "_is_custom_patch", False): - if TYPE_CHECKING: - from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj - - LiteLLMLoggingObj = _LiteLLMLoggingObj - else: - LiteLLMLoggingObj = Any - - def custom_transform_response( - self, - model: str, - raw_response: Response, - model_response: ModelResponse, - logging_obj: LiteLLMLoggingObj, - request_data: dict, - messages: List[AllMessageValues], - optional_params: dict, - litellm_params: dict, - encoding: str, - **kwargs, - ): - """ - Custom transform response from - .venv/lib/python3.11/site-packages/litellm/llms/ollama/completion/transformation.py - - Additional kwargs: - api_key: Optional[str] - API key for authentication - json_mode: Optional[bool] - JSON mode flag - """ - logger.info("Custom transform_response is running") - response_json = raw_response.json() - - model_response.choices[0].finish_reason = "stop" - model_response.choices[0].message.content = response_json["response"] - - _prompt = request_data.get("prompt", "") - prompt_tokens = response_json.get( - "prompt_eval_count", - len(encoding.encode(_prompt, disallowed_special=())), - ) - completion_tokens = response_json.get("eval_count", len(response_json.get("message", {}).get("content", ""))) - - setattr( - model_response, - "usage", - 
litellm.Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - model_response.created = int(time.time()) - model_response.model = "ollama/" + model - return model_response - - # Mark it to avoid double patching - custom_transform_response._is_custom_patch = True - - # Patch it - OllamaConfig.transform_response = custom_transform_response diff --git a/tests/server/integration/test_endpoints_testbed.py b/tests/server/integration/test_endpoints_testbed.py index 40430599..6d0e35d9 100644 --- a/tests/server/integration/test_endpoints_testbed.py +++ b/tests/server/integration/test_endpoints_testbed.py @@ -10,7 +10,7 @@ from unittest.mock import patch, MagicMock import pytest from conftest import get_test_db_payload -from common.schema import TestSetQA as QATestSet, Evaluation, EvaluationReport +from common.schema import QASetData as QATestSet, Evaluation, EvaluationReport #############################################################################
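
A note on the get_namespace change earlier in this patch: the bare `except OciException: raise` is placed before the broader handlers so an error already normalized inside init_client is not wrapped a second time. A minimal sketch of that ordering, with hypothetical stand-ins for the client factory and config:

    class OciException(Exception):
        """Simplified stand-in for the server's OciException."""
        def __init__(self, status_code: int, detail: str):
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail

    def get_namespace_sketch(make_client):
        try:
            client = make_client()          # may itself raise OciException
            return client.get_namespace().data
        except OciException:
            raise                           # already well-formed; pass through
        except Exception as ex:             # everything else is normalized once
            raise OciException(status_code=500, detail=str(ex)) from ex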
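
The databases_update rewrite above validates credentials against a throwaway copy before mutating stored state: model_copy builds a trial config from the stored object plus only the fields the caller actually sent (exclude_unset), and the real object is updated only after the connection succeeds. Roughly, with a simplified schema and an injected connect helper in place of the server's versions:

    from typing import Optional
    from pydantic import BaseModel

    class DbConfig(BaseModel):              # simplified stand-in for the server schema
        name: str
        user: Optional[str] = None
        dsn: Optional[str] = None

    def update_db(db: DbConfig, payload: DbConfig, connect) -> DbConfig:
        # Overlay only the fields the caller actually sent; `db` stays
        # untouched until the trial connection succeeds.
        test_config = db.model_copy(update=payload.model_dump(exclude_unset=True))
        connect(test_config)                # raises on bad credentials/DSN
        for key, value in payload.model_dump(exclude_unset=True).items():
            setattr(db, key, value)
        return db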
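
The deleted litellm patch modules all followed the same monkey-patching discipline their docstrings describe: check a marker attribute on the target callable before replacing it, so a re-import cannot patch twice. The shape of that guard-flag pattern, with hypothetical names:

    def apply_patch(target_cls, attr_name):
        original = getattr(target_cls, attr_name, None)
        if original is None or getattr(original, "_is_custom_patch", False):
            return  # target not importable, or already patched

        def patched(*args, **kwargs):
            # ...custom behaviour goes here, delegating to the original...
            return original(*args, **kwargs)

        patched._is_custom_patch = True  # guard against double patching
        setattr(target_cls, attr_name, patched)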
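
The removed OCI streaming patch worked by back-filling required keys with empty defaults before strict Pydantic validation of a partial chunk, since early streamed tool calls may omit fields that later chunks fill in. A self-contained sketch of that defaulting-before-validation idea (ToolCall is illustrative, not the actual litellm model):

    from pydantic import BaseModel

    class ToolCall(BaseModel):
        id: str
        name: str
        arguments: str

    def coerce_tool_call(raw: dict) -> ToolCall:
        # Default missing required fields to "" so validation passes;
        # extra keys such as "type" are ignored by the default model config.
        for field in ("id", "name", "arguments"):
            raw.setdefault(field, "")
        return ToolCall.model_validate(raw)

    coerce_tool_call({"type": "FUNCTION", "name": "get_weather"})
    # -> ToolCall(id='', name='get_weather', arguments='')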