Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f0c063b
[core] Support token based auth in ray dashboard UI
Nov 3, 2025
0d3316c
fix lint issues
Nov 3, 2025
7660ade
address comments
Nov 3, 2025
5234bfe
fix lint
Nov 3, 2025
b30bf04
Merge remote-tracking branch 'upstream/master' into token_auth_8
Nov 4, 2025
f61c27f
fix typo
Nov 4, 2025
2da0b9b
[core] Configure an interceptor to pass auth token in python direct g…
Nov 4, 2025
1f1897d
fix lint issues
Nov 4, 2025
6aa4ce0
reduce code duplication
Nov 4, 2025
8186c09
empty commit
Nov 4, 2025
d90ffd3
[Core] Add Service Interceptor to support token authentication in das…
Nov 5, 2025
b464cd5
[core] Token auth usability improvements
Nov 5, 2025
29f511a
address comment
Nov 5, 2025
26de367
add test_grpc_authentication_server_interceptor to BUILD.bazel
Nov 5, 2025
03e5c52
Merge branch 'token_auth_10' into token_auth_11
sampan-s-nayak Nov 5, 2025
2b5e01b
fix typo
Nov 5, 2025
f38591d
Merge branch 'token_auth_10' into token_auth_11
sampan-s-nayak Nov 5, 2025
c018203
[core] use client interceptor for adding auth token in c++ client calls
Nov 6, 2025
d66af16
fix lint issues
Nov 6, 2025
fbcfea4
separate out intererceptor code
Nov 6, 2025
d59d8ea
Merge branch 'master' into token_auth_10
edoakes Nov 10, 2025
267b440
Merge branch 'token_auth_10' into token_auth_11
sampan-s-nayak Nov 11, 2025
4e85216
Merge branch 'master' into token_auth_10
sampan-s-nayak Nov 11, 2025
e0b8b93
Merge branch 'token_auth_10' into token_auth_11
sampan-s-nayak Nov 11, 2025
214b0e1
Merge branch 'token_auth_11' into token_auth_12
sampan-s-nayak Nov 11, 2025
1240a8c
empty commit
Nov 11, 2025
c857c11
Merge branch 'token_auth_10' into token_auth_11
sampan-s-nayak Nov 11, 2025
3e473ba
address comment
Nov 11, 2025
86f62f2
Merge branch 'token_auth_11' into token_auth_12
sampan-s-nayak Nov 11, 2025
6fe7473
Merge branch 'master' into token_auth_11
edoakes Nov 11, 2025
1b0b89b
Merge branch 'master' of https://github.com/ray-project/ray into toke…
edoakes Nov 11, 2025
0a473bf
Merge branch 'token_auth_11' into token_auth_12
sampan-s-nayak Nov 11, 2025
ab8ea56
Merge branch 'master' into token_auth_11
edoakes Nov 12, 2025
790150c
Merge branch 'token_auth_11' into token_auth_12
sampan-s-nayak Nov 12, 2025
ca93ce0
Merge branch 'master' into token_auth_12
sampan-s-nayak Nov 12, 2025
99513f7
fix import after merge
Nov 12, 2025
0317fda
fix lint issues
Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def ensure_token_if_auth_enabled(
3. Generate and save a default token for new local clusters if one doesn't already exist.

Args:
system_config: Ray raises an error if you set auth_mode in system_config instead of the environment.
system_config: Ray raises an error if you set AUTH_MODE in system_config instead of the environment.
create_token_if_missing: Generate a new token if one doesn't already exist.

Raises:
Expand All @@ -79,11 +79,11 @@ def ensure_token_if_auth_enabled(
if get_authentication_mode() != AuthenticationMode.TOKEN:
if (
system_config
and "auth_mode" in system_config
and system_config["auth_mode"] != "disabled"
and "AUTH_MODE" in system_config
and system_config["AUTH_MODE"] != "disabled"
):
raise RuntimeError(
"Set authentication mode can only be set with the `RAY_auth_mode` environment variable, not using the system_config."
"Set authentication mode can only be set with the `RAY_AUTH_MODE` environment variable, not using the system_config."
)
return

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""gRPC client interceptor for token-based authentication."""

import logging
from collections import namedtuple
from typing import Tuple

import grpc
from grpc import aio as aiogrpc

from ray._raylet import AuthenticationTokenLoader

logger = logging.getLogger(__name__)


# Immutable record of the per-call fields an interceptor hands on to the
# continuation. The field order below is part of the type's interface.
_ClientCallDetails = namedtuple(
    "_ClientCallDetails",
    [
        "method",
        "timeout",
        "metadata",
        "credentials",
        "wait_for_ready",
        "compression",
    ],
)


def _get_authentication_metadata_tuple() -> Tuple[Tuple[str, str], ...]:
    """Build the gRPC metadata entries that carry the authentication token.

    Currently only token authentication is supported.

    Returns:
        An empty tuple when the token loader reports that no token is
        available; otherwise one ``(key, value)`` pair per entry of the
        loader's HTTP header mapping, with every key lowercased.
    """
    token_loader = AuthenticationTokenLoader.instance()
    if not token_loader.has_token():
        return ()

    headers = token_loader.get_token_for_http_header()

    # gRPC expects metadata shaped as (("key", "value"), ...) and only accepts
    # lowercase keys, whereas HTTP header names are case-insensitive and are
    # often capitalized. Normalize here so a header such as "Authorization"
    # cannot make the outgoing call fail on invalid metadata.
    return tuple((name.lower(), value) for name, value in headers.items())


class AuthenticationMetadataClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Synchronous gRPC client interceptor that adds authentication metadata.

    All four RPC shapes are treated the same way: the call details are
    rebuilt with the authentication entries appended to their metadata, and
    the call is then passed on to the continuation unchanged otherwise.
    """

    def _intercept_call_details(self, client_call_details):
        """Return new call details with authentication metadata appended."""
        # Caller-supplied metadata comes first; authentication entries follow.
        combined_metadata = [
            *(client_call_details.metadata or []),
            *_get_authentication_metadata_tuple(),
        ]

        # wait_for_ready / compression are read with a None default, so an
        # incoming details object that lacks them is still accepted.
        return _ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            metadata=combined_metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
            compression=getattr(client_call_details, "compression", None),
        )

    def _forward(self, continuation, client_call_details, payload):
        """Invoke ``continuation`` with authenticated details and ``payload``."""
        return continuation(
            self._intercept_call_details(client_call_details), payload
        )

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._forward(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._forward(continuation, client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        return self._forward(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        return self._forward(continuation, client_call_details, request_iterator)


class AsyncAuthenticationMetadataClientInterceptor(
    aiogrpc.UnaryUnaryClientInterceptor,
    aiogrpc.UnaryStreamClientInterceptor,
    aiogrpc.StreamUnaryClientInterceptor,
    aiogrpc.StreamStreamClientInterceptor,
):
    """Async gRPC client interceptor that adds authentication metadata.

    All four RPC shapes are treated the same way: the call details are
    rebuilt with the authentication entries appended to their metadata, and
    the continuation is awaited with those details.
    """

    def _intercept_call_details(self, client_call_details):
        """Return new call details with authentication metadata appended."""
        # Caller-supplied metadata comes first; authentication entries follow.
        combined_metadata = [
            *(client_call_details.metadata or []),
            *_get_authentication_metadata_tuple(),
        ]

        # wait_for_ready / compression are read with a None default, so an
        # incoming details object that lacks them is still accepted.
        return _ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            metadata=combined_metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
            compression=getattr(client_call_details, "compression", None),
        )

    async def _forward(self, continuation, client_call_details, payload):
        """Await ``continuation`` with authenticated details and ``payload``."""
        return await continuation(
            self._intercept_call_details(client_call_details), payload
        )

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await self._forward(continuation, client_call_details, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        return await self._forward(continuation, client_call_details, request)

    async def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        return await self._forward(
            continuation, client_call_details, request_iterator
        )

    async def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        return await self._forward(
            continuation, client_call_details, request_iterator
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""gRPC server interceptor for token-based authentication."""

import logging
from typing import Awaitable, Callable

import grpc
from grpc import aio as aiogrpc

from ray._private.authentication.authentication_constants import (
AUTHORIZATION_HEADER_NAME,
)
from ray._private.authentication.authentication_utils import (
is_token_auth_enabled,
validate_request_token,
)

logger = logging.getLogger(__name__)


class AsyncAuthenticationServerInterceptor(aiogrpc.ServerInterceptor):
    """Async gRPC server interceptor that validates authentication tokens.

    This interceptor checks the "authorization" metadata header for a valid
    Bearer token when token authentication is enabled via RAY_AUTH_MODE=token.
    If the token is missing or invalid, the request is rejected with
    UNAUTHENTICATED status.
    """

    def _validate_authentication(self, metadata: tuple) -> bool:
        """Validate the authentication token carried in gRPC metadata.

        Args:
            metadata: gRPC metadata as an iterable of (key, value) pairs.
                ``None`` is treated as empty metadata.

        Returns:
            True if authentication succeeds or is not required, False otherwise.
        """
        # If token auth is not enabled, allow all requests.
        if not is_token_auth_enabled():
            return True

        # Use the first entry whose key matches, compared case-insensitively.
        auth_header = None
        for key, value in metadata or ():
            if key.lower() == AUTHORIZATION_HEADER_NAME:
                auth_header = value
                break

        if not auth_header:
            logger.warning(
                "Authentication required but no authorization header provided"
            )
            return False

        # validate_request_token returns True if the header is valid,
        # False otherwise.
        return validate_request_token(auth_header)

    async def intercept_service(
        self,
        continuation: Callable[
            [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler]
        ],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> grpc.RpcMethodHandler:
        """Intercept service calls to validate authentication.

        This method is called once per RPC to get the handler. The handler's
        behavior is wrapped so that authentication is validated before the
        actual RPC method runs.

        Args:
            continuation: Awaitable callable that resolves the next handler
                for ``handler_call_details``.
            handler_call_details: Details of the incoming RPC.

        Returns:
            A method handler with the same streaming flags and
            (de)serializers as the resolved one, whose behavior first checks
            authentication; or None when no handler was resolved.
        """
        # Imported locally so this block does not rely on a module-level
        # import that the surrounding file may not have.
        import inspect

        handler = await continuation(handler_call_details)
        if handler is None:
            return None

        async def abort_if_unauthenticated(context):
            """Abort the RPC with UNAUTHENTICATED unless its token is valid."""
            if not self._validate_authentication(context.invocation_metadata()):
                await context.abort(
                    grpc.StatusCode.UNAUTHENTICATED,
                    "Invalid or missing authentication token",
                )

        def wrap_rpc_behavior(behavior):
            """Wrap an RPC method to validate authentication first.

            NOTE(review): behaviors are assumed to be async (coroutine
            functions or async generator functions); plain synchronous
            handlers are not supported here, same as before.
            """
            if behavior is None:
                return None

            if inspect.isasyncgenfunction(behavior):
                # Response-streaming handlers written as async generators
                # return an async generator, which cannot be awaited. The
                # wrapper must itself be an async generator and re-yield
                # every response.
                async def wrapped_stream(request_or_iterator, context):
                    await abort_if_unauthenticated(context)
                    async for response in behavior(request_or_iterator, context):
                        yield response

                return wrapped_stream

            async def wrapped(request_or_iterator, context):
                await abort_if_unauthenticated(context)
                return await behavior(request_or_iterator, context)

            return wrapped

        # Pick the public factory and the behavior that match the RPC shape,
        # so the returned object is a regular RpcMethodHandler.
        if handler.request_streaming and handler.response_streaming:
            make_handler = grpc.stream_stream_rpc_method_handler
            behavior = handler.stream_stream
        elif handler.request_streaming:
            make_handler = grpc.stream_unary_rpc_method_handler
            behavior = handler.stream_unary
        elif handler.response_streaming:
            make_handler = grpc.unary_stream_rpc_method_handler
            behavior = handler.unary_stream
        else:
            make_handler = grpc.unary_unary_rpc_method_handler
            behavior = handler.unary_unary

        return make_handler(
            wrap_rpc_behavior(behavior),
            request_deserializer=handler.request_deserializer,
            response_serializer=handler.response_serializer,
        )
23 changes: 19 additions & 4 deletions python/ray/_private/authentication/http_token_authentication.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import logging
from types import ModuleType
from typing import Dict, Optional
from typing import Dict, List, Optional

from ray._private.authentication import authentication_constants
from ray.dashboard import authentication_utils as auth_utils
from ray._private.authentication import (
authentication_constants,
authentication_utils as auth_utils,
)

logger = logging.getLogger(__name__)


def get_token_auth_middleware(aiohttp_module: ModuleType):
def get_token_auth_middleware(
aiohttp_module: ModuleType,
whitelisted_exact_paths: Optional[List[str]] = None,
whitelisted_path_prefixes: Optional[List[str]] = None,
):
"""Internal helper to create token auth middleware with provided modules.

Args:
aiohttp_module: The aiohttp module to use
whitelisted_exact_paths: List of exact paths that don't require authentication
whitelisted_path_prefixes: List of path prefixes that don't require authentication
Returns:
An aiohttp middleware function
"""
Expand All @@ -28,6 +36,13 @@ async def token_auth_middleware(request, handler):
if not auth_utils.is_token_auth_enabled():
return await handler(request)

# skip authentication for whitelisted paths
if (whitelisted_exact_paths and request.path in whitelisted_exact_paths) or (
whitelisted_path_prefixes
and request.path.startswith(tuple(whitelisted_path_prefixes))
):
return await handler(request)

auth_header = request.headers.get(
authentication_constants.AUTHORIZATION_HEADER_NAME, ""
)
Expand Down
34 changes: 31 additions & 3 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ def init_grpc_channel(
import grpc
from grpc import aio as aiogrpc

from ray._private.authentication import authentication_utils
from ray._private.tls_utils import load_certs_from_env

grpc_module = aiogrpc if asynchronous else grpc
Expand All @@ -1040,16 +1041,43 @@ def init_grpc_channel(
)
options = options_dict.items()

if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
# Build interceptors list
interceptors = []
if authentication_utils.is_token_auth_enabled():
from ray._private.authentication.grpc_authentication_client_interceptor import (
AsyncAuthenticationMetadataClientInterceptor,
AuthenticationMetadataClientInterceptor,
)

if asynchronous:
interceptors.append(AsyncAuthenticationMetadataClientInterceptor())
else:
interceptors.append(AuthenticationMetadataClientInterceptor())

# Create channel with TLS if enabled
use_tls = os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true")
if use_tls:
server_cert_chain, private_key, ca_cert = load_certs_from_env()
credentials = grpc.ssl_channel_credentials(
certificate_chain=server_cert_chain,
private_key=private_key,
root_certificates=ca_cert,
)
channel = grpc_module.secure_channel(address, credentials, options=options)
channel_creator = grpc_module.secure_channel
base_args = (address, credentials)
else:
channel_creator = grpc_module.insecure_channel
base_args = (address,)

# Create channel (async channels get interceptors in constructor, sync via intercept_channel)
if asynchronous:
channel = channel_creator(
*base_args, options=options, interceptors=interceptors
)
else:
channel = grpc_module.insecure_channel(address, options=options)
channel = channel_creator(*base_args, options=options)
if interceptors:
channel = grpc.intercept_channel(channel, *interceptors)

return channel

Expand Down
Loading