Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions agent/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ type GitManager struct {

// FleetManager represents the Fleet ConfigManager configuration.
type FleetManager struct {
URL string `yaml:"url"`
TokenURL string `yaml:"token_url"`
Timeout *int `yaml:"timeout,omitempty"`
SkipTLS bool `yaml:"skip_tls"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
URL string `yaml:"url"`
TokenURL string `yaml:"token_url"`
Timeout *int `yaml:"timeout,omitempty"`
SkipTLS bool `yaml:"skip_tls"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
TokenExpiryCheckInterval *int `yaml:"token_expiry_check_interval,omitempty"` // Check interval in seconds (default: 30)
TokenReconnectBuffer *int `yaml:"token_reconnect_buffer,omitempty"` // Reconnect buffer in seconds before expiry (default: 120)
}

// Sources represents the configuration for manager sources, including cloud, local and git.
Expand Down
164 changes: 155 additions & 9 deletions agent/configmgr/fleet.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,32 @@ import (
var _ Manager = (*fleetConfigManager)(nil)

type fleetConfigManager struct {
logger *slog.Logger
connection *fleet.MQTTConnection
authTokenManager *fleet.AuthTokenManager
resetChan chan struct{}
backendState backend.StateRetriever
policyManager policymgr.PolicyManager
otlpBridge *otlpbridge.BridgeServer
logger *slog.Logger
connection *fleet.MQTTConnection
authTokenManager *fleet.AuthTokenManager
resetChan chan struct{}
reconnectChan chan struct{}
backendState backend.StateRetriever
policyManager policymgr.PolicyManager
otlpBridge *otlpbridge.BridgeServer
config config.Config
backends map[string]backend.Backend
labels map[string]string
configYaml string
connectionDetails fleet.ConnectionDetails
monitorCtx context.Context
monitorCancel context.CancelFunc
}

func newFleetConfigManager(logger *slog.Logger, pMgr policymgr.PolicyManager, backendState backend.StateRetriever) *fleetConfigManager {
resetChan := make(chan struct{}, 1)
reconnectChan := make(chan struct{}, 1)
return &fleetConfigManager{
logger: logger,
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, backendState),
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, reconnectChan, backendState),
authTokenManager: fleet.NewAuthTokenManager(logger),
resetChan: resetChan,
reconnectChan: reconnectChan,
backendState: backendState,
policyManager: pMgr,
}
Expand Down Expand Up @@ -106,6 +116,14 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
if err != nil {
return fmt.Errorf("failed to convert config to safe string: %w", err)
}

// Store connection state for reconnection
fleetManager.config = cfg
fleetManager.backends = backends
fleetManager.labels = cfg.OrbAgent.Labels
fleetManager.configYaml = string(configYaml)
fleetManager.connectionDetails = connectionDetails

err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
if err != nil {
return err
Expand Down Expand Up @@ -158,6 +176,69 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
fleetManager.logger.Info("OTLP bridge bound to Fleet MQTT", slog.String("topic", topics.Ingest))
})

// Start goroutine to handle reconnect requests (JWT refresh)
go func() {
for range fleetManager.reconnectChan {
fleetManager.logger.Info("JWT refresh and reconnection requested")
if err := fleetManager.refreshAndReconnect(ctx, timeout); err != nil {
fleetManager.logger.Error("failed to refresh and reconnect", "error", err)
}
}
}()

// Start background goroutine to monitor token expiry and trigger proactive reconnection
fleetManager.monitorCtx, fleetManager.monitorCancel = context.WithCancel(context.Background())
go fleetManager.monitorTokenExpiry()

return nil
}

// refreshAndReconnect obtains a fresh JWT, rebuilds the MQTT connection
// details from its claims, and reconnects using those details.
//
// The steps run strictly in order: refresh the token, parse its claims,
// regenerate the topic set, record the new connection details on the manager,
// then call Reconnect with the backends, labels and config YAML stored on the
// manager. The first failing step aborts the sequence and its error is
// returned wrapped with context. Note that the new connection details are
// recorded before Reconnect is attempted, so they remain stored even when
// the reconnect itself fails.
func (fleetManager *fleetConfigManager) refreshAndReconnect(ctx context.Context, timeout time.Duration) error {
	// Step 1: fetch a new token.
	refreshed, err := fleetManager.authTokenManager.RefreshToken(ctx)
	if err != nil {
		return fmt.Errorf("failed to refresh token: %w", err)
	}

	// Step 2: read the claims carried by the new token.
	claims, err := fleet.ParseJWTClaims(refreshed.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to parse JWT claims: %w", err)
	}

	// Step 3: derive the topic names from those claims.
	newTopics, err := fleet.GenerateTopicsFromTemplate(claims)
	if err != nil {
		return fmt.Errorf("failed to generate topics: %w", err)
	}

	fleetManager.logger.Info("refreshed JWT and generated new topics",
		"heartbeat_topic", newTopics.Heartbeat,
		"capabilities_topic", newTopics.Capabilities,
		"inbox_topic", newTopics.Inbox,
		"outbox_topic", newTopics.Outbox)

	// Step 4: assemble and remember the details for the new session.
	details := fleet.ConnectionDetails{
		MQTTURL:  claims.MqttURL,
		Token:    refreshed.AccessToken,
		AgentID:  claims.AgentID,
		Topics:   *newTopics,
		ClientID: fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.ClientID,
		Zone:     claims.Zone,
	}
	fleetManager.connectionDetails = details

	// Step 5: re-establish the MQTT session with the new token.
	if err := fleetManager.connection.Reconnect(ctx, details, fleetManager.backends, fleetManager.labels, fleetManager.configYaml, timeout); err != nil {
		return fmt.Errorf("failed to reconnect: %w", err)
	}

	fleetManager.logger.Info("successfully refreshed JWT and reconnected")
	return nil
}

Expand All @@ -178,8 +259,73 @@ func (fleetManager *fleetConfigManager) GetContext(ctx context.Context) context.
return ctx
}

// Stop gracefully shuts down the OTLP bridge.
// monitorTokenExpiry polls the auth token manager on a fixed interval and
// requests a reconnect when the token has expired or is about to.
//
// Tuning comes from the Fleet source configuration:
//   - TokenExpiryCheckInterval: seconds between checks (default 30s; nil or
//     non-positive values keep the default).
//   - TokenReconnectBuffer: how many seconds ahead of expiry a proactive
//     reconnect is requested (default 2m; nil or non-positive values keep the
//     default).
//
// A reconnect is requested with a non-blocking send on reconnectChan: if the
// channel cannot accept the signal immediately, the trigger is skipped rather
// than blocking the monitor loop. The loop exits when monitorCtx is cancelled.
func (fleetManager *fleetConfigManager) monitorTokenExpiry() {
	fleetCfg := fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet

	// Check interval: default 30 seconds, configurable via config.
	checkInterval := 30 * time.Second
	if secs := fleetCfg.TokenExpiryCheckInterval; secs != nil && *secs > 0 {
		checkInterval = time.Duration(*secs) * time.Second
	}

	// Reconnect buffer: default 2 minutes before expiry, configurable via config.
	reconnectBuffer := 2 * time.Minute
	if secs := fleetCfg.TokenReconnectBuffer; secs != nil && *secs > 0 {
		reconnectBuffer = time.Duration(*secs) * time.Second
	}

	// requestReconnect performs the non-blocking signal shared by both the
	// "expired" and "expiring soon" paths; sentMsg is logged when the signal
	// was actually queued.
	requestReconnect := func(sentMsg string) {
		select {
		case fleetManager.reconnectChan <- struct{}{}:
			fleetManager.logger.Debug(sentMsg)
		default:
			fleetManager.logger.Debug("reconnection already in progress, skipping duplicate trigger")
		}
	}

	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()

	fleetManager.logger.Info("starting token expiry monitor",
		"check_interval", checkInterval,
		"reconnect_buffer", reconnectBuffer)

	for {
		select {
		case <-fleetManager.monitorCtx.Done():
			fleetManager.logger.Info("token expiry monitor stopped")
			return
		case <-ticker.C:
			switch {
			case fleetManager.authTokenManager.IsTokenExpired():
				fleetManager.logger.Warn("JWT token has expired, triggering reconnection",
					"expiry_time", fleetManager.authTokenManager.GetTokenExpiryTime())
				requestReconnect("reconnection signal sent due to expired token")
			case fleetManager.authTokenManager.IsTokenExpiringSoon(reconnectBuffer):
				fleetManager.logger.Warn("JWT token expiring soon, triggering proactive reconnection",
					"expiry_time", fleetManager.authTokenManager.GetTokenExpiryTime(),
					"reconnect_buffer", reconnectBuffer)
				requestReconnect("reconnection signal sent due to imminent token expiry")
			default:
				// Token is healthy; only log when an expiry time is known.
				if expiryTime := fleetManager.authTokenManager.GetTokenExpiryTime(); !expiryTime.IsZero() {
					fleetManager.logger.Debug("token expiry check passed",
						"expiry_time", expiryTime,
						"time_until_expiry", time.Until(expiryTime))
				}
			}
		}
	}
}

// Stop gracefully shuts down the OTLP bridge and token expiry monitor.
func (fleetManager *fleetConfigManager) Stop(ctx context.Context) error {
// Stop token expiry monitor
if fleetManager.monitorCancel != nil {
fleetManager.monitorCancel()
}

if fleetManager.otlpBridge != nil {
if err := fleetManager.otlpBridge.Stop(ctx); err != nil {
fleetManager.logger.Error("error while stopping OTLP bridge", slog.Any("error", err))
Expand Down
102 changes: 98 additions & 4 deletions agent/configmgr/fleet/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,21 @@ import (
"net/url"
"strings"
"time"

"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
)

// AuthTokenManager manages auth tokens
type AuthTokenManager struct {
logger *slog.Logger
logger *slog.Logger
tokenURL string
skipTLS bool
timeout time.Duration
clientID string
clientSecret string
lastToken *TokenResponse
tokenExpiresAt time.Time
}

// NewAuthTokenManager creates a new AuthTokenManager
Expand Down Expand Up @@ -46,6 +56,13 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
return nil, fmt.Errorf("client secret cannot be empty")
}

// Store credentials for future refresh
fleetManager.tokenURL = tokenURL
fleetManager.skipTLS = skipTLS
fleetManager.timeout = timeout
fleetManager.clientID = clientID
fleetManager.clientSecret = clientSecret

fleetManager.logger.Debug("requesting access token", "token_url", tokenURL, "client_id", clientID)

scopes := []string{
Expand All @@ -60,8 +77,6 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
data.Set("client_secret", clientSecret)
data.Set("audience", "orb")

fleetManager.logger.Debug("sending token request", "url", tokenURL, "data", data, "client_id", clientID) //, "client_secret", clientSecret)

req, err := http.NewRequest("POST", tokenURL, bytes.NewBufferString(data.Encode()))
if err != nil {
fleetManager.logger.Error("failed to create token request", "error", err, "token_url", tokenURL)
Expand All @@ -77,7 +92,8 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
},
}

fleetManager.logger.Debug("sending token request", "url", tokenURL)
fleetManager.logger.Debug("sending token request", "url", tokenURL, "data", data, "client_id", clientID)

resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
fleetManager.logger.Error("failed to send token request", "error", err, "token_url", tokenURL)
Expand Down Expand Up @@ -121,5 +137,83 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
"expires_in", TokenResponse.ExpiresIn,
"mqtt_url", TokenResponse.MQTTURL)

// Store token and calculate expiration time
fleetManager.lastToken = &TokenResponse

// Try to parse JWT exp claim for more accurate expiry tracking
var expiryTime time.Time
if parsedExpiry, err := parseJWTExpiry(TokenResponse.AccessToken); err == nil && !parsedExpiry.IsZero() {
// Use JWT exp claim with 5-minute buffer for safety
expiryTime = parsedExpiry.Add(-5 * time.Minute)
fleetManager.logger.Debug("using JWT exp claim for token expiry", "expiry", parsedExpiry, "buffer_applied", expiryTime)
} else if TokenResponse.ExpiresIn > 0 {
// Fallback to ExpiresIn from response (with 5-minute buffer)
expiryTime = time.Now().Add(time.Duration(TokenResponse.ExpiresIn)*time.Second - 5*time.Minute)
fleetManager.logger.Debug("using ExpiresIn for token expiry", "expires_in", TokenResponse.ExpiresIn, "buffer_applied", expiryTime)
}

fleetManager.tokenExpiresAt = expiryTime

return &TokenResponse, nil
}

// RefreshToken requests a new token by calling GetToken with the token URL,
// TLS setting, timeout and client credentials currently stored on the
// manager. It fails without making a request when no token URL has been
// stored yet.
func (fleetManager *AuthTokenManager) RefreshToken(ctx context.Context) (*TokenResponse, error) {
	if len(fleetManager.tokenURL) == 0 {
		return nil, fmt.Errorf("cannot refresh token: credentials not initialized")
	}

	fleetManager.logger.Info("refreshing JWT token")

	// Replay the stored request parameters unchanged.
	tokenURL, skipTLS, timeout := fleetManager.tokenURL, fleetManager.skipTLS, fleetManager.timeout
	clientID, clientSecret := fleetManager.clientID, fleetManager.clientSecret
	return fleetManager.GetToken(ctx, tokenURL, skipTLS, timeout, clientID, clientSecret)
}

// IsTokenExpired reports whether the stored token should be treated as
// expired: either no token has been stored yet, or the current time is past
// the recorded expiry time.
func (fleetManager *AuthTokenManager) IsTokenExpired() bool {
	return fleetManager.lastToken == nil || time.Now().After(fleetManager.tokenExpiresAt)
}

// IsTokenExpiringSoon reports whether the stored token's recorded expiry
// falls within buffer from now. A missing token or an unset (zero) expiry
// time is also reported as expiring soon.
func (fleetManager *AuthTokenManager) IsTokenExpiringSoon(buffer time.Duration) bool {
	deadline := fleetManager.tokenExpiresAt
	if fleetManager.lastToken == nil || deadline.IsZero() {
		return true
	}
	horizon := time.Now().Add(buffer)
	return horizon.After(deadline)
}

// GetTokenExpiryTime returns the expiry time currently recorded for the
// stored token. The value is returned exactly as recorded; it is the zero
// time when no expiry has been recorded.
func (fleetManager *AuthTokenManager) GetTokenExpiryTime() time.Time {
	recorded := fleetManager.tokenExpiresAt
	return recorded
}

// parseJWTExpiry returns the expiry instant carried in the exp claim of a
// compact-serialized signed JWT.
//
// The token's signature is NOT verified; the claims are read only to learn
// when the token expires, never to make a trust decision.
//
// An error is returned when tokenString is empty, when the token cannot be
// parsed as a signed JWT using one of the listed signature algorithms, when
// its claims cannot be decoded, or when it carries no exp claim. On error the
// returned time is the zero value.
func parseJWTExpiry(tokenString string) (time.Time, error) {
	if tokenString == "" {
		return time.Time{}, fmt.Errorf("empty token string")
	}

	// Parse the JWT token without verification. Only tokens whose header
	// names one of these algorithms are accepted by the parser.
	token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512})
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to parse JWT token: %w", err)
	}

	var claims jwt.Claims

	// Extract standard claims without verification. Pass only real
	// destinations: the method is variadic and decodes the payload into every
	// argument, so a stray nil argument is handed to the JSON decoder, which
	// rejects a nil destination and makes the whole call fail.
	if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
		return time.Time{}, fmt.Errorf("failed to extract claims from JWT: %w", err)
	}

	// Check if exp claim exists.
	if claims.Expiry == nil {
		return time.Time{}, fmt.Errorf("exp claim not found in JWT token")
	}

	return claims.Expiry.Time(), nil
}
Loading
Loading