Skip to content

Commit 5fcf0d3

Browse files
authored
feat: refreshes the token and reconnects if publishing fails (#228)
1 parent bf96ca6 commit 5fcf0d3

File tree

12 files changed

+1085
-90
lines changed

12 files changed

+1085
-90
lines changed

agent/config/types.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ type GitManager struct {
3535

3636
// FleetManager holds the settings of the Fleet config manager source as they
// appear in the YAML configuration.
type FleetManager struct {
	URL      string `yaml:"url"`
	TokenURL string `yaml:"token_url"`
	Timeout  *int   `yaml:"timeout,omitempty"`
	SkipTLS  bool   `yaml:"skip_tls"`

	// Client credentials (presumably the ones presented to TokenURL — confirm
	// against the caller).
	ClientID     string `yaml:"client_id"`
	ClientSecret string `yaml:"client_secret"`

	// Token refresh tuning, both expressed in seconds and both optional.
	// TokenExpiryCheckInterval is how often token expiry is checked
	// (default: 30). TokenReconnectBuffer is how long before expiry a
	// reconnect is requested (default: 120).
	TokenExpiryCheckInterval *int `yaml:"token_expiry_check_interval,omitempty"`
	TokenReconnectBuffer     *int `yaml:"token_reconnect_buffer,omitempty"`
}
4547

4648
// Sources represents the configuration for manager sources, including cloud, local and git.

agent/configmgr/fleet.go

Lines changed: 155 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,32 @@ import (
2020
var _ Manager = (*fleetConfigManager)(nil)
2121

2222
type fleetConfigManager struct {
23-
logger *slog.Logger
24-
connection *fleet.MQTTConnection
25-
authTokenManager *fleet.AuthTokenManager
26-
resetChan chan struct{}
27-
backendState backend.StateRetriever
28-
policyManager policymgr.PolicyManager
29-
otlpBridge *otlpbridge.BridgeServer
23+
logger *slog.Logger
24+
connection *fleet.MQTTConnection
25+
authTokenManager *fleet.AuthTokenManager
26+
resetChan chan struct{}
27+
reconnectChan chan struct{}
28+
backendState backend.StateRetriever
29+
policyManager policymgr.PolicyManager
30+
otlpBridge *otlpbridge.BridgeServer
31+
config config.Config
32+
backends map[string]backend.Backend
33+
labels map[string]string
34+
configYaml string
35+
connectionDetails fleet.ConnectionDetails
36+
monitorCtx context.Context
37+
monitorCancel context.CancelFunc
3038
}
3139

3240
func newFleetConfigManager(logger *slog.Logger, pMgr policymgr.PolicyManager, backendState backend.StateRetriever) *fleetConfigManager {
3341
resetChan := make(chan struct{}, 1)
42+
reconnectChan := make(chan struct{}, 1)
3443
return &fleetConfigManager{
3544
logger: logger,
36-
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, backendState),
45+
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, reconnectChan, backendState),
3746
authTokenManager: fleet.NewAuthTokenManager(logger),
3847
resetChan: resetChan,
48+
reconnectChan: reconnectChan,
3949
backendState: backendState,
4050
policyManager: pMgr,
4151
}
@@ -106,6 +116,14 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
106116
if err != nil {
107117
return fmt.Errorf("failed to convert config to safe string: %w", err)
108118
}
119+
120+
// Store connection state for reconnection
121+
fleetManager.config = cfg
122+
fleetManager.backends = backends
123+
fleetManager.labels = cfg.OrbAgent.Labels
124+
fleetManager.configYaml = string(configYaml)
125+
fleetManager.connectionDetails = connectionDetails
126+
109127
err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
110128
if err != nil {
111129
return err
@@ -158,6 +176,69 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
158176
fleetManager.logger.Info("OTLP bridge bound to Fleet MQTT", slog.String("topic", topics.Ingest))
159177
})
160178

179+
// Start goroutine to handle reconnect requests (JWT refresh)
180+
go func() {
181+
for range fleetManager.reconnectChan {
182+
fleetManager.logger.Info("JWT refresh and reconnection requested")
183+
if err := fleetManager.refreshAndReconnect(ctx, timeout); err != nil {
184+
fleetManager.logger.Error("failed to refresh and reconnect", "error", err)
185+
}
186+
}
187+
}()
188+
189+
// Start background goroutine to monitor token expiry and trigger proactive reconnection
190+
fleetManager.monitorCtx, fleetManager.monitorCancel = context.WithCancel(context.Background())
191+
go fleetManager.monitorTokenExpiry()
192+
193+
return nil
194+
}
195+
196+
// refreshAndReconnect refreshes the JWT token and reconnects to MQTT
197+
func (fleetManager *fleetConfigManager) refreshAndReconnect(ctx context.Context, timeout time.Duration) error {
198+
// Refresh JWT token
199+
token, err := fleetManager.authTokenManager.RefreshToken(ctx)
200+
if err != nil {
201+
return fmt.Errorf("failed to refresh token: %w", err)
202+
}
203+
204+
// Parse new JWT claims
205+
jwtClaims, err := fleet.ParseJWTClaims(token.AccessToken)
206+
if err != nil {
207+
return fmt.Errorf("failed to parse JWT claims: %w", err)
208+
}
209+
210+
// Regenerate topics
211+
topics, err := fleet.GenerateTopicsFromTemplate(jwtClaims)
212+
if err != nil {
213+
return fmt.Errorf("failed to generate topics: %w", err)
214+
}
215+
216+
fleetManager.logger.Info("refreshed JWT and generated new topics",
217+
"heartbeat_topic", topics.Heartbeat,
218+
"capabilities_topic", topics.Capabilities,
219+
"inbox_topic", topics.Inbox,
220+
"outbox_topic", topics.Outbox)
221+
222+
// Update connection details
223+
newConnectionDetails := fleet.ConnectionDetails{
224+
MQTTURL: jwtClaims.MqttURL,
225+
Token: token.AccessToken,
226+
AgentID: jwtClaims.AgentID,
227+
Topics: *topics,
228+
ClientID: fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.ClientID,
229+
Zone: jwtClaims.Zone,
230+
}
231+
232+
// Store updated connection details
233+
fleetManager.connectionDetails = newConnectionDetails
234+
235+
// Reconnect with new token
236+
err = fleetManager.connection.Reconnect(ctx, newConnectionDetails, fleetManager.backends, fleetManager.labels, fleetManager.configYaml, timeout)
237+
if err != nil {
238+
return fmt.Errorf("failed to reconnect: %w", err)
239+
}
240+
241+
fleetManager.logger.Info("successfully refreshed JWT and reconnected")
161242
return nil
162243
}
163244

@@ -178,8 +259,73 @@ func (fleetManager *fleetConfigManager) GetContext(ctx context.Context) context.
178259
return ctx
179260
}
180261

181-
// Stop gracefully shuts down the OTLP bridge.
262+
// monitorTokenExpiry periodically checks token expiry and triggers reconnection before token expires
263+
func (fleetManager *fleetConfigManager) monitorTokenExpiry() {
264+
// Check interval: default 30 seconds, configurable via config
265+
checkInterval := 30 * time.Second
266+
if fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenExpiryCheckInterval != nil && *fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenExpiryCheckInterval > 0 {
267+
checkInterval = time.Duration(*fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenExpiryCheckInterval) * time.Second
268+
}
269+
270+
// Reconnect buffer: default 2 minutes before expiry, configurable via config
271+
reconnectBuffer := 2 * time.Minute
272+
if fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenReconnectBuffer != nil && *fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenReconnectBuffer > 0 {
273+
reconnectBuffer = time.Duration(*fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.TokenReconnectBuffer) * time.Second
274+
}
275+
276+
ticker := time.NewTicker(checkInterval)
277+
defer ticker.Stop()
278+
279+
fleetManager.logger.Info("starting token expiry monitor",
280+
"check_interval", checkInterval,
281+
"reconnect_buffer", reconnectBuffer)
282+
283+
for {
284+
select {
285+
case <-fleetManager.monitorCtx.Done():
286+
fleetManager.logger.Info("token expiry monitor stopped")
287+
return
288+
case <-ticker.C:
289+
// Check if token is expired or expiring soon
290+
if fleetManager.authTokenManager.IsTokenExpired() {
291+
fleetManager.logger.Warn("JWT token has expired, triggering reconnection",
292+
"expiry_time", fleetManager.authTokenManager.GetTokenExpiryTime())
293+
select {
294+
case fleetManager.reconnectChan <- struct{}{}:
295+
fleetManager.logger.Debug("reconnection signal sent due to expired token")
296+
default:
297+
fleetManager.logger.Debug("reconnection already in progress, skipping duplicate trigger")
298+
}
299+
} else if fleetManager.authTokenManager.IsTokenExpiringSoon(reconnectBuffer) {
300+
fleetManager.logger.Warn("JWT token expiring soon, triggering proactive reconnection",
301+
"expiry_time", fleetManager.authTokenManager.GetTokenExpiryTime(),
302+
"reconnect_buffer", reconnectBuffer)
303+
select {
304+
case fleetManager.reconnectChan <- struct{}{}:
305+
fleetManager.logger.Debug("reconnection signal sent due to imminent token expiry")
306+
default:
307+
fleetManager.logger.Debug("reconnection already in progress, skipping duplicate trigger")
308+
}
309+
} else {
310+
expiryTime := fleetManager.authTokenManager.GetTokenExpiryTime()
311+
if !expiryTime.IsZero() {
312+
timeUntilExpiry := time.Until(expiryTime)
313+
fleetManager.logger.Debug("token expiry check passed",
314+
"expiry_time", expiryTime,
315+
"time_until_expiry", timeUntilExpiry)
316+
}
317+
}
318+
}
319+
}
320+
}
321+
322+
// Stop gracefully shuts down the OTLP bridge and token expiry monitor.
182323
func (fleetManager *fleetConfigManager) Stop(ctx context.Context) error {
324+
// Stop token expiry monitor
325+
if fleetManager.monitorCancel != nil {
326+
fleetManager.monitorCancel()
327+
}
328+
183329
if fleetManager.otlpBridge != nil {
184330
if err := fleetManager.otlpBridge.Stop(ctx); err != nil {
185331
fleetManager.logger.Error("error while stopping OTLP bridge", slog.Any("error", err))

agent/configmgr/fleet/auth.go

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,21 @@ import (
1212
"net/url"
1313
"strings"
1414
"time"
15+
16+
"github.com/go-jose/go-jose/v4"
17+
"github.com/go-jose/go-jose/v4/jwt"
1518
)
1619

1720
// AuthTokenManager manages auth tokens
1821
type AuthTokenManager struct {
19-
logger *slog.Logger
22+
logger *slog.Logger
23+
tokenURL string
24+
skipTLS bool
25+
timeout time.Duration
26+
clientID string
27+
clientSecret string
28+
lastToken *TokenResponse
29+
tokenExpiresAt time.Time
2030
}
2131

2232
// NewAuthTokenManager creates a new AuthTokenManager
@@ -46,6 +56,13 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
4656
return nil, fmt.Errorf("client secret cannot be empty")
4757
}
4858

59+
// Store credentials for future refresh
60+
fleetManager.tokenURL = tokenURL
61+
fleetManager.skipTLS = skipTLS
62+
fleetManager.timeout = timeout
63+
fleetManager.clientID = clientID
64+
fleetManager.clientSecret = clientSecret
65+
4966
fleetManager.logger.Debug("requesting access token", "token_url", tokenURL, "client_id", clientID)
5067

5168
scopes := []string{
@@ -60,8 +77,6 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
6077
data.Set("client_secret", clientSecret)
6178
data.Set("audience", "orb")
6279

63-
fleetManager.logger.Debug("sending token request", "url", tokenURL, "data", data, "client_id", clientID) //, "client_secret", clientSecret)
64-
6580
req, err := http.NewRequest("POST", tokenURL, bytes.NewBufferString(data.Encode()))
6681
if err != nil {
6782
fleetManager.logger.Error("failed to create token request", "error", err, "token_url", tokenURL)
@@ -77,7 +92,8 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
7792
},
7893
}
7994

80-
fleetManager.logger.Debug("sending token request", "url", tokenURL)
95+
// The form payload is deliberately not logged: it carries client_secret.
fleetManager.logger.Debug("sending token request", "url", tokenURL, "client_id", clientID)
96+
8197
resp, err := httpClient.Do(req.WithContext(ctx))
8298
if err != nil {
8399
fleetManager.logger.Error("failed to send token request", "error", err, "token_url", tokenURL)
@@ -121,5 +137,83 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
121137
"expires_in", TokenResponse.ExpiresIn,
122138
"mqtt_url", TokenResponse.MQTTURL)
123139

140+
// Store token and calculate expiration time
141+
fleetManager.lastToken = &TokenResponse
142+
143+
// Try to parse JWT exp claim for more accurate expiry tracking
144+
var expiryTime time.Time
145+
if parsedExpiry, err := parseJWTExpiry(TokenResponse.AccessToken); err == nil && !parsedExpiry.IsZero() {
146+
// Use JWT exp claim with 5-minute buffer for safety
147+
expiryTime = parsedExpiry.Add(-5 * time.Minute)
148+
fleetManager.logger.Debug("using JWT exp claim for token expiry", "expiry", parsedExpiry, "buffer_applied", expiryTime)
149+
} else if TokenResponse.ExpiresIn > 0 {
150+
// Fallback to ExpiresIn from response (with 5-minute buffer)
151+
expiryTime = time.Now().Add(time.Duration(TokenResponse.ExpiresIn)*time.Second - 5*time.Minute)
152+
fleetManager.logger.Debug("using ExpiresIn for token expiry", "expires_in", TokenResponse.ExpiresIn, "buffer_applied", expiryTime)
153+
}
154+
155+
fleetManager.tokenExpiresAt = expiryTime
156+
124157
return &TokenResponse, nil
125158
}
159+
160+
// RefreshToken refreshes the auth token using stored credentials
161+
func (fleetManager *AuthTokenManager) RefreshToken(ctx context.Context) (*TokenResponse, error) {
162+
if fleetManager.tokenURL == "" {
163+
return nil, fmt.Errorf("cannot refresh token: credentials not initialized")
164+
}
165+
166+
fleetManager.logger.Info("refreshing JWT token")
167+
return fleetManager.GetToken(ctx, fleetManager.tokenURL, fleetManager.skipTLS, fleetManager.timeout, fleetManager.clientID, fleetManager.clientSecret)
168+
}
169+
170+
// IsTokenExpired checks if the current token is expired or will expire soon
171+
func (fleetManager *AuthTokenManager) IsTokenExpired() bool {
172+
if fleetManager.lastToken == nil {
173+
return true
174+
}
175+
return time.Now().After(fleetManager.tokenExpiresAt)
176+
}
177+
178+
// IsTokenExpiringSoon checks if the token will expire within the specified duration
179+
func (fleetManager *AuthTokenManager) IsTokenExpiringSoon(buffer time.Duration) bool {
180+
if fleetManager.lastToken == nil {
181+
return true
182+
}
183+
if fleetManager.tokenExpiresAt.IsZero() {
184+
return true
185+
}
186+
return time.Now().Add(buffer).After(fleetManager.tokenExpiresAt)
187+
}
188+
189+
// GetTokenExpiryTime returns the time when the current token expires (with buffer already applied)
190+
func (fleetManager *AuthTokenManager) GetTokenExpiryTime() time.Time {
191+
return fleetManager.tokenExpiresAt
192+
}
193+
194+
// parseJWTExpiry extracts the exp claim from a JWT token
195+
func parseJWTExpiry(tokenString string) (time.Time, error) {
196+
if tokenString == "" {
197+
return time.Time{}, fmt.Errorf("empty token string")
198+
}
199+
200+
// Parse the JWT token without verification
201+
token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256, jose.HS384, jose.HS512, jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512})
202+
if err != nil {
203+
return time.Time{}, fmt.Errorf("failed to parse JWT token: %w", err)
204+
}
205+
206+
var claims jwt.Claims
207+
208+
// Extract standard claims without verification
209+
if err := token.UnsafeClaimsWithoutVerification(&claims, nil); err != nil {
210+
return time.Time{}, fmt.Errorf("failed to extract claims from JWT: %w", err)
211+
}
212+
213+
// Check if exp claim exists
214+
if claims.Expiry == nil {
215+
return time.Time{}, fmt.Errorf("exp claim not found in JWT token")
216+
}
217+
218+
return claims.Expiry.Time(), nil
219+
}

0 commit comments

Comments
 (0)