diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index a33dc5c20da3..790032d60316 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -822,6 +822,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
         }
 
         ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+            // Check if context is cancelled before processing result
+            if (context->IsCancelled()) {
+                ctx_server.cancel_tasks(task_ids);
+                return false;
+            }
+
             json res_json = result->to_json();
             if (res_json.is_array()) {
                 for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
                 reply.set_message(error_data.value("content", ""));
                 writer->Write(reply);
                 return true;
-            }, [&]() {
-                // NOTE: we should try to check when the writer is closed here
-                return false;
+            }, [&context]() {
+                // Check if the gRPC context is cancelled
+                return context->IsCancelled();
             });
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
+            // Check if context was cancelled during processing
+            if (context->IsCancelled()) {
+                return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+            }
+
             return grpc::Status::OK;
         }
 
@@ -1145,6 +1156,14 @@ class BackendServiceImpl final : public backend::Backend::Service {
         std::cout << "[DEBUG] Waiting for results..." << std::endl;
+
+        // Check cancellation before waiting for results
+        if (context->IsCancelled()) {
+            ctx_server.cancel_tasks(task_ids);
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
             std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
             if (results.size() == 1) {
@@ -1176,13 +1195,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
         }, [&](const json & error_data) {
             std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
             reply->set_message(error_data.value("content", ""));
-        }, [&]() {
-            return false;
+        }, [&context]() {
+            // Check if the gRPC context is cancelled
+            // This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
+            return context->IsCancelled();
         });
 
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
 
+        // Check if context was cancelled during processing
+        if (context->IsCancelled()) {
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         return grpc::Status::OK;
     }
 
@@ -1234,6 +1260,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
             ctx_server.queue_tasks.post(std::move(tasks));
         }
 
+        // Check cancellation before waiting for results
+        if (context->IsCancelled()) {
+            ctx_server.cancel_tasks(task_ids);
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         // get the result
         ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
             for (auto & res : results) {
@@ -1242,12 +1275,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
             }
         }, [&](const json & error_data) {
             error = true;
-        }, [&]() {
-            return false;
+        }, [&context]() {
+            // Check if the gRPC context is cancelled
+            return context->IsCancelled();
         });
 
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
+        // Check if context was cancelled during processing
+        if (context->IsCancelled()) {
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         if (error) {
             return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
         }
@@ -1325,6 +1364,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
             ctx_server.queue_tasks.post(std::move(tasks));
         }
 
+        // Check cancellation before waiting for results
+        if (context->IsCancelled()) {
+            ctx_server.cancel_tasks(task_ids);
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         // Get the results
         ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
             for (auto & res : results) {
@@ -1333,12 +1379,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
             }
         }, [&](const json & error_data) {
             error = true;
-        }, [&]() {
-            return false;
+        }, [&context]() {
+            // Check if the gRPC context is cancelled
+            return context->IsCancelled();
         });
 
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 
+        // Check if context was cancelled during processing
+        if (context->IsCancelled()) {
+            return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
+        }
+
         if (error) {
             return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
         }
diff --git a/core/config/model_config.go b/core/config/model_config.go
index 87fa05fce8ca..1dee82363d2f 100644
--- a/core/config/model_config.go
+++ b/core/config/model_config.go
@@ -93,19 +93,18 @@ type AgentConfig struct {
     EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
 }
 
-func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
+func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
     var remote MCPGenericConfig[MCPRemoteServers]
     var stdio MCPGenericConfig[MCPSTDIOServers]
 
     if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
-        return remote, stdio
+        return remote, stdio, err
     }
 
     if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
-        return remote, stdio
+        return remote, stdio, err
     }
-
-    return remote, stdio
+    return remote, stdio, nil
 }
 
 type MCPGenericConfig[T any] struct {
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index d1ce156215c4..3691e058aa19 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -3,8 +3,10 @@ package openai
 import (
     "bufio"
     "bytes"
+    "context"
     "encoding/json"
     "fmt"
+    "net"
     "time"
 
     "github.com/gofiber/fiber/v2"
@@ -22,6 +24,59 @@ import (
     "github.com/valyala/fasthttp"
 )
 
+// NOTE: this is a bad WORKAROUND! We should find a better way to handle this.
+// Fasthttp doesn't support context cancellation from the caller
+// for non-streaming requests, so we need to monitor the connection directly.
+// Monitor connection for client disconnection during non-streaming requests
+// We access the connection directly via c.Context().Conn() to monitor it
+// during ComputeChoices execution, not after the response is sent
+// see: https://github.com/mudler/LocalAI/pull/7187#issuecomment-3506720906
+func handleConnectionCancellation(c *fiber.Ctx, cancelFunc func(), requestCtx context.Context) {
+    var conn net.Conn = c.Context().Conn()
+    if conn == nil {
+        return
+    }
+
+    go func() {
+        defer func() {
+            // Clear read deadline when goroutine exits
+            conn.SetReadDeadline(time.Time{})
+        }()
+
+        buf := make([]byte, 1)
+        // Use a short read deadline to periodically check if connection is closed
+        // Without a deadline, Read() would block indefinitely waiting for data
+        // that will never come (client is waiting for response, not sending more data)
+        ticker := time.NewTicker(100 * time.Millisecond)
+        defer ticker.Stop()
+
+        for {
+            select {
+            case <-requestCtx.Done():
+                // Request completed or was cancelled - exit goroutine
+                return
+            case <-ticker.C:
+                // Set a short deadline - if connection is closed, read will fail immediately
+                // If connection is open but no data, it will timeout and we check again
+                conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
+                _, err := conn.Read(buf)
+                if err != nil {
+                    // Check if it's a timeout (connection still open, just no data)
+                    if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+                        // Timeout is expected - connection is still open, just no data to read
+                        // Continue the loop to check again
+                        continue
+                    }
+                    // Connection closed or other error - cancel the context to stop gRPC call
+                    log.Debug().Msgf("Calling cancellation function")
+                    cancelFunc()
+                    return
+                }
+            }
+        }
+    }()
+}
+
 // ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
 // @Summary Generate a chat completions for a given prompt and model.
 // @Param request body schema.OpenAIRequest true "query params"
@@ -358,6 +413,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
     LOOP:
         for {
             select {
+            case <-input.Context.Done():
+                // Context was cancelled (client disconnected or request cancelled)
+                log.Debug().Msgf("Request context cancelled, stopping stream")
+                input.Cancel()
+                break LOOP
             case ev := <-responses:
                 if len(ev.Choices) == 0 {
                     log.Debug().Msgf("No choices in the response, skipping")
@@ -511,6 +571,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 
         }
 
+        // NOTE: this is a workaround as fasthttp
+        // context cancellation does not fire in non-streaming requests
+        handleConnectionCancellation(c, input.Cancel, input.Context)
+
         result, tokenUsage, err := ComputeChoices(
             input,
             predInput,
diff --git a/core/http/endpoints/openai/mcp.go b/core/http/endpoints/openai/mcp.go
index 89d6aa5fa6f6..fe018bbbd09c 100644
--- a/core/http/endpoints/openai/mcp.go
+++ b/core/http/endpoints/openai/mcp.go
@@ -1,6 +1,7 @@
 package openai
 
 import (
+    "context"
     "encoding/json"
     "errors"
     "fmt"
@@ -50,12 +51,15 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
     }
 
     // Get MCP config from model config
-    remote, stdio := config.MCP.MCPConfigFromYAML()
+    remote, stdio, err := config.MCP.MCPConfigFromYAML()
+    if err != nil {
+        return fmt.Errorf("failed to get MCP config: %w", err)
+    }
 
     // Check if we have tools in cache, or we have to have an initial connection
     sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
     if err != nil {
-        return err
+        return fmt.Errorf("failed to get MCP sessions: %w", err)
     }
 
     if len(sessions) == 0 {
@@ -73,6 +77,10 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
     if appConfig.ApiKeys != nil {
         apiKey = appConfig.ApiKeys[0]
     }
+
+    ctxWithCancellation, cancel := context.WithCancel(ctx)
+    defer cancel()
+    handleConnectionCancellation(c, cancel, ctxWithCancellation)
 
     // TODO: instead of connecting to the API, we should just wire this internally
     // and act like completion.go.
     // We can do this as cogito expects an interface and we can create one that
@@ -83,7 +91,7 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
         cogito.WithStatusCallback(func(s string) {
             log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
         }),
-        cogito.WithContext(ctx),
+        cogito.WithContext(ctxWithCancellation),
         cogito.WithMCPs(sessions...),
         cogito.WithIterations(3), // default to 3 iterations
         cogito.WithMaxAttempts(3), // default to 3 attempts
diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go
index 35f39f7f37f9..4ec9613711c2 100644
--- a/core/http/middleware/request.go
+++ b/core/http/middleware/request.go
@@ -161,7 +161,17 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
     correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
     ctx.Set("X-Correlation-ID", correlationID)
 
+    //c1, cancel := context.WithCancel(re.applicationConfig.Context)
+    // Use the application context as parent to ensure cancellation on app shutdown
+    // We'll monitor the Fiber context separately and cancel our context when the request is canceled
     c1, cancel := context.WithCancel(re.applicationConfig.Context)
+    // Monitor the Fiber context and cancel our context when it's canceled
+    // This ensures we respect request cancellation without causing panics
+    go func() {
+        <-ctx.Context().Done()
+        // Fiber context was canceled (request completed or client disconnected)
+        cancel()
+    }()
 
     // Add the correlation ID to the new context
     ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
diff --git a/core/http/static/chat.js b/core/http/static/chat.js
index 9b10a626e967..4517b0baf5bf 100644
--- a/core/http/static/chat.js
+++ b/core/http/static/chat.js
@@ -27,21 +27,43 @@ SOFTWARE.
 */
 
+// Global variable to store the current AbortController
+let currentAbortController = null;
+let currentReader = null;
+
 function toggleLoader(show) {
-    const loader = document.getElementById('loader');
     const sendButton = document.getElementById('send-button');
+    const stopButton = document.getElementById('stop-button');
     if (show) {
-        loader.style.display = 'block';
         sendButton.style.display = 'none';
+        stopButton.style.display = 'block';
         document.getElementById("input").disabled = true;
     } else {
         document.getElementById("input").disabled = false;
-        loader.style.display = 'none';
         sendButton.style.display = 'block';
+        stopButton.style.display = 'none';
+        currentAbortController = null;
+        currentReader = null;
     }
 }
 
+function stopRequest() {
+    if (currentAbortController) {
+        currentAbortController.abort();
+        currentAbortController = null;
+    }
+    if (currentReader) {
+        currentReader.cancel();
+        currentReader = null;
+    }
+    toggleLoader(false);
+    Alpine.store("chat").add(
+        "assistant",
+        `Request cancelled by user`,
+    );
+}
+
 function processThinkingTags(content) {
     const thinkingRegex = /<thinking>(.*?)<\/thinking>|<think>(.*?)<\/think>/gs;
     const parts = content.split(thinkingRegex);
@@ -295,8 +317,9 @@ async function promptGPT(systemPrompt, input) {
     let response;
     try {
-        // Create AbortController for timeout handling
+        // Create AbortController for timeout handling and stop button
        const controller = new AbortController();
+        currentAbortController = controller; // Store globally so stop button can abort it
        const timeoutId = setTimeout(() => controller.abort(), mcpMode ? 300000 : 30000); // 5 minutes for MCP, 30 seconds for regular
 
        response = await fetch(endpoint, {
@@ -311,11 +334,20 @@ async function promptGPT(systemPrompt, input) {
 
        clearTimeout(timeoutId);
    } catch (error) {
+        // Don't show error if request was aborted by user (stop button)
        if (error.name === 'AbortError') {
-            Alpine.store("chat").add(
-                "assistant",
-                `Request timeout: MCP processing is taking longer than expected. Please try again.`,
-            );
+            // Check if this was a user-initiated abort (stop button was clicked)
+            // If currentAbortController is null, it means stopRequest() was called and already handled the UI
+            if (!currentAbortController) {
+                // User clicked stop button - error message already shown by stopRequest()
+                return;
+            } else {
+                // Timeout error (controller was aborted by timeout, not user)
+                Alpine.store("chat").add(
+                    "assistant",
+                    `Request timeout: MCP processing is taking longer than expected. Please try again.`,
+                );
+            }
        } else {
            Alpine.store("chat").add(
                "assistant",
@@ -323,6 +355,7 @@ async function promptGPT(systemPrompt, input) {
            );
        }
        toggleLoader(false);
+        currentAbortController = null;
        return;
    }
 
@@ -332,6 +365,7 @@ async function promptGPT(systemPrompt, input) {
            `Error: POST ${endpoint} ${response.status}`,
        );
        toggleLoader(false);
+        currentAbortController = null;
        return;
    }
 
@@ -360,10 +394,15 @@ async function promptGPT(systemPrompt, input) {
            // Highlight all code blocks
            hljs.highlightAll();
        } catch (error) {
-            Alpine.store("chat").add(
-                "assistant",
-                `Error: Failed to parse MCP response`,
-            );
+            // Don't show error if request was aborted by user
+            if (error.name !== 'AbortError' || currentAbortController) {
+                Alpine.store("chat").add(
+                    "assistant",
+                    `Error: Failed to parse MCP response`,
+                );
+            }
+        } finally {
+            currentAbortController = null;
        }
    } else {
        // Handle regular streaming response
@@ -376,9 +415,13 @@ async function promptGPT(systemPrompt, input) {
                "assistant",
                `Error: Failed to decode API response`,
            );
+            toggleLoader(false);
            return;
        }
 
+        // Store reader globally so stop button can cancel it
+        currentReader = reader;
+
        // Function to add content to the chat and handle DOM updates efficiently
        const addToChat = (token) => {
            const chatStore = Alpine.store("chat");
@@ -479,13 +522,20 @@ async function promptGPT(systemPrompt, input) {
            // Highlight all code blocks once at the end
            hljs.highlightAll();
        } catch (error) {
-            Alpine.store("chat").add(
-                "assistant",
-                `Error: Failed to process stream`,
-            );
+            // Don't show error if request was aborted by user
+            if (error.name !== 'AbortError' || !currentAbortController) {
+                Alpine.store("chat").add(
+                    "assistant",
+                    `Error: Failed to process stream`,
+                );
+            }
        } finally {
            // Perform any cleanup if necessary
-            reader.releaseLock();
+            if (reader) {
+                reader.releaseLock();
+            }
+            currentReader = null;
+            currentAbortController = null;
        }
    }
 
diff --git a/core/http/views/chat.html b/core/http/views/chat.html
index ff9ed3ee6089..86338402f330 100644
--- a/core/http/views/chat.html
+++ b/core/http/views/chat.html
@@ -402,15 +402,19 @@
 [hunk markup not recoverable: next to the upload control (title="Upload text, markdown or PDF file"), the existing send button is removed and replaced by a send button plus a stop button (id="stop-button") that triggers stopRequest(); toggleLoader() switches between the two]
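
Not part of the patch: a minimal sketch, assuming the standard google.golang.org/grpc status and codes packages, of how Go code that calls these backend RPCs could tell the new CANCELLED status apart from a genuine failure. isClientCancellation and the stubbed error value are illustrative only, not code from this PR.

package main

import (
    "context"
    "errors"
    "fmt"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// isClientCancellation reports whether an RPC error (or the caller's own
// context) indicates the request was cancelled rather than having failed.
func isClientCancellation(ctx context.Context, err error) bool {
    if errors.Is(ctx.Err(), context.Canceled) {
        return true
    }
    return status.Code(err) == codes.Canceled
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    cancel() // simulate the client disconnecting mid-request

    // Stand-in for the error a backend call such as Predict/PredictStream would
    // return once the server starts replying with grpc::StatusCode::CANCELLED.
    err := status.Error(codes.Canceled, "Request cancelled by client")

    fmt.Println(isClientCancellation(ctx, err)) // true
}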