|
15 | 15 | //***************************************************************************** |
16 | 16 | #include "http_rest_api_handler.hpp" |
17 | 17 |
|
| 18 | +#include <algorithm> |
18 | 19 | #include <cctype> |
19 | 20 | #include <iomanip> |
20 | 21 | #include <memory> |
@@ -123,7 +124,8 @@ const std::string HttpRestApiHandler::v3_RegexExp = |
123 | 124 |
|
124 | 125 | const std::string HttpRestApiHandler::metricsRegexExp = R"((.?)\/metrics(\?(.*))?)"; |
125 | 126 |
|
126 | | -HttpRestApiHandler::HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms) : |
| 127 | +HttpRestApiHandler::HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms, const std::string& apiKey) : |
| 128 | + apiKey(apiKey), |
127 | 129 | predictionRegex(predictionRegexExp), |
128 | 130 | modelstatusRegex(modelstatusRegexExp), |
129 | 131 | configReloadRegex(configReloadRegexExp), |
@@ -668,14 +670,36 @@ Status HttpRestApiHandler::processListModelsRequest(std::string& response) { |
668 | 670 | return StatusCode::OK; |
669 | 671 | } |
670 | 672 |
|
| 673 | +bool HttpRestApiHandler::isAuthorized(const std::unordered_map<std::string, std::string>& headers, const std::string& apiKey) { |
| 674 | + std::unordered_map<std::string, std::string> lowercaseHeaders; |
| 675 | + for (const auto& [key, value] : headers) { |
| 676 | + std::string lowercaseKey = key; |
| 677 | + std::transform(lowercaseKey.begin(), lowercaseKey.end(), lowercaseKey.begin(), ::tolower); |
| 678 | + if (lowercaseKey == "authorization") { |
| 679 | + if (value == "Bearer " + apiKey) { |
| 680 | + return true; |
| 681 | + } else { |
| 682 | + SPDLOG_DEBUG("Unauthorized request - invalid API key provided."); |
| 683 | + return false; |
| 684 | + } |
| 685 | + } |
| 686 | + } |
| 687 | + SPDLOG_DEBUG("Unauthorized request - missing API key"); |
| 688 | + return false; |
| 689 | +} |
| 690 | + |
671 | 691 | Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) { |
672 | 692 | #if (MEDIAPIPE_DISABLE == 0) |
673 | 693 | OVMS_PROFILE_FUNCTION(); |
674 | 694 |
|
675 | 695 | HttpPayload request; |
676 | 696 | std::string modelName; |
677 | 697 | bool streamFieldVal = false; |
678 | | - |
| 698 | + if (!this->apiKey.empty()) { |
| 699 | + if (!isAuthorized(request_components.headers, this->apiKey)) { |
| 700 | + return StatusCode::UNAUTHORIZED; |
| 701 | + } |
| 702 | + } |
679 | 703 | auto status = createV3HttpPayload(uri, request_components, response, request_body, serverReaderWriter, std::move(multiPartParser), request, modelName, streamFieldVal); |
680 | 704 | if (!status.ok()) { |
681 | 705 | SPDLOG_DEBUG("Failed to create V3 payload: {}", status.string()); |
|
0 commit comments