diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index cbd3ea024..abbba09f9 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -106,6 +106,7 @@ var ( certPath = flag.String("cert-path", runserver.DefaultCertPath, "The path to the certificate for secure serving. The certificate and private key files "+ "are assumed to be named tls.crt and tls.key, respectively. If not set, and secureServing is enabled, "+ "then a self-signed certificate is used.") + enableCertReload = flag.Bool("enable-cert-reload", runserver.DefaultCertReload, "Enables certificate reloading of the certificates specified in --cert-path") // metric flags totalQueuedRequestsMetric = flag.String("total-queued-requests-metric", runserver.DefaultTotalQueuedRequestsMetric, "Prometheus metric for the number of queued requests.") kvCacheUsagePercentageMetric = flag.String("kv-cache-usage-percentage-metric", runserver.DefaultKvCacheUsagePercentageMetric, "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).") @@ -345,6 +346,7 @@ func (r *Runner) Run(ctx context.Context) error { SecureServing: *secureServing, HealthChecking: *healthChecking, CertPath: *certPath, + EnableCertReload: *enableCertReload, RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, MetricsStalenessThreshold: *metricsStalenessThreshold, Director: director, diff --git a/go.mod b/go.mod index 097852877..86940bf53 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/elastic/crd-ref-docs v0.2.0 github.com/envoyproxy/go-control-plane/envoy v1.35.0 + github.com/fsnotify/fsnotify v1.9.0 github.com/go-logr/logr v1.4.3 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 @@ -58,7 +59,6 @@ require ( github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect diff --git a/pkg/common/certs.go b/pkg/common/certs.go new file mode 100644 index 000000000..ead1107dc --- /dev/null +++ b/pkg/common/certs.go @@ -0,0 +1,103 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "context" + "crypto/tls" + "fmt" + "sync/atomic" + "time" + + "github.com/fsnotify/fsnotify" + "sigs.k8s.io/controller-runtime/pkg/log" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// debounceDelay wait for events to settle before reloading +const debounceDelay = 250 * time.Millisecond + +type CertReloader struct { + cert *atomic.Pointer[tls.Certificate] +} + +func NewCertReloader(ctx context.Context, path string, init *tls.Certificate) (*CertReloader, error) { + certPtr := &atomic.Pointer[tls.Certificate]{} + certPtr.Store(init) + + w, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create cert watcher: %w", err) + } + + logger := log.FromContext(ctx). + WithName("cert-reloader"). + WithValues("path", path) + traceLogger := logger.V(logutil.TRACE) + + if err := w.Add(path); err != nil { + _ = w.Close() // Clean up watcher before returning + return nil, fmt.Errorf("failed to watch %q: %w", path, err) + } + + go func() { + defer w.Close() + + var debounceTimer *time.Timer + + for { + select { + case ev := <-w.Events: + traceLogger.Info("Cert changed", "event", ev) + + if ev.Op&(fsnotify.Write|fsnotify.Create) == 0 { + continue + } + + // Debounce: reset the timer if we get another event + if debounceTimer != nil { + debounceTimer.Stop() + } + + debounceTimer = time.AfterFunc(debounceDelay, func() { + // This runs after the delay with no new events + cert, err := tls.LoadX509KeyPair(path+"/tls.crt", path+"/tls.key") + if err != nil { + logger.Error(err, "Failed to reload TLS certificate") + return + } + certPtr.Store(&cert) + traceLogger.Info("Reloaded TLS certificate") + }) + + case err := <-w.Errors: + if err != nil { + logger.Error(err, "cert watcher failed") + } + case <-ctx.Done(): + return + } + } + }() + + return &CertReloader{cert: certPtr}, nil +} + +func (r *CertReloader) Get() *tls.Certificate { + return r.cert.Load() +} diff --git a/pkg/common/certs_test.go b/pkg/common/certs_test.go new file mode 100644 index 000000000..0a51a22b4 --- /dev/null +++ b/pkg/common/certs_test.go @@ -0,0 +1,350 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func TestCertReloader_InitialLoad(t *testing.T) { + t.Parallel() + + // Create initial certificate + certPEM1, keyPEM1, err := createTestCertificate(1) + if err != nil { + t.Fatalf("failed to create test certificate: %v", err) + } + + // Setup K8s-style directory + baseDir := setupK8sStyleCertDir(t, certPEM1, keyPEM1) + + // Load the initial certificate + cert, err := tls.LoadX509KeyPair(filepath.Join(baseDir, "tls.crt"), filepath.Join(baseDir, "tls.key")) + if err != nil { + t.Fatalf("failed to load initial certificate: %v", err) + } + + // Create cert reloader with debug logging + ctx, cancel := context.WithCancel(logutil.NewTestLoggerIntoContext(context.Background())) + defer cancel() + + reloader, err := NewCertReloader(ctx, baseDir, &cert) + if err != nil { + t.Fatalf("failed to create cert reloader: %v", err) + } + + // Verify initial certificate is loaded + loadedCert := reloader.Get() + if loadedCert == nil { + t.Fatal("expected certificate to be loaded") + } + + // Verify it's the correct certificate by checking serial number + if len(loadedCert.Certificate) == 0 { + t.Fatal("loaded certificate has no certificate chain") + } + + x509Cert, err := x509.ParseCertificate(loadedCert.Certificate[0]) + if err != nil { + t.Fatalf("failed to parse loaded certificate: %v", err) + } + + if x509Cert.SerialNumber.Int64() != 1 { + t.Errorf("expected certificate serial number 1, got %d", x509Cert.SerialNumber.Int64()) + } +} + +func TestCertReloader_MultipleUpdates(t *testing.T) { + t.Parallel() + + initialSerialNumber := int64(1) + + // Create initial certificate + certPEM1, keyPEM1, err := createTestCertificate(initialSerialNumber) + if err != nil { + t.Fatalf("failed to create test certificate: %v", err) + } + + baseDir := setupK8sStyleCertDir(t, certPEM1, keyPEM1) + + // Load the initial certificate + cert, err := tls.LoadX509KeyPair(filepath.Join(baseDir, "tls.crt"), filepath.Join(baseDir, "tls.key")) + if err != nil { + t.Fatalf("failed to load initial certificate: %v", err) + } + + // Create cert reloader with debug logging + ctx, cancel := context.WithCancel(logutil.NewTestLoggerIntoContext(context.Background())) + defer cancel() + + reloader, err := NewCertReloader(ctx, baseDir, &cert) + if err != nil { + t.Fatalf("failed to create cert reloader: %v", err) + } + + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + // Perform multiple sequential updates + for i := initialSerialNumber + 1; i <= 5; i++ { + certPEM, keyPEM, err := createTestCertificate(i) + if err != nil { + t.Fatalf("failed to create test certificate %d: %v", i, err) + } + + updateK8sStyleCerts(t, baseDir, certPEM, keyPEM) + + // Wait for reload + timeout := time.After(10 * time.Second) + + reloaded := false + for !reloaded { + select { + case <-timeout: + t.Fatalf("timeout waiting for certificate reload to serial %d", i) + case <-ticker.C: + currentCert := reloader.Get() + if len(currentCert.Certificate) > 0 { + x509Cert, err := x509.ParseCertificate(currentCert.Certificate[0]) + if err != nil { + t.Fatalf("failed to parse current certificate: %v", err) + } + if x509Cert.SerialNumber.Int64() == i { + reloaded = true + } + } + } + } + } + + // Verify final certificate + finalCert := reloader.Get() + x509FinalCert, err := x509.ParseCertificate(finalCert.Certificate[0]) + if err != nil { + t.Fatalf("failed to parse final certificate: %v", err) + } + + if x509FinalCert.SerialNumber.Int64() != 5 { + t.Errorf("expected final certificate serial number 5, got %d", x509FinalCert.SerialNumber.Int64()) + } +} + +func TestCertReloader_ErrorHandling(t *testing.T) { + t.Parallel() + + // Create initial valid certificate + certPEM1, keyPEM1, err := createTestCertificate(1) + if err != nil { + t.Fatalf("failed to create test certificate: %v", err) + } + + // Setup K8s-style directory + baseDir := setupK8sStyleCertDir(t, certPEM1, keyPEM1) + + // Load the initial certificate + cert, err := tls.LoadX509KeyPair(filepath.Join(baseDir, "tls.crt"), filepath.Join(baseDir, "tls.key")) + if err != nil { + t.Fatalf("failed to load initial certificate: %v", err) + } + + // Create cert reloader with debug logging + ctx, cancel := context.WithCancel(logutil.NewTestLoggerIntoContext(context.Background())) + defer cancel() + + reloader, err := NewCertReloader(ctx, baseDir, &cert) + if err != nil { + t.Fatalf("failed to create cert reloader: %v", err) + } + + // Verify initial certificate + initialCert := reloader.Get() + x509InitialCert, err := x509.ParseCertificate(initialCert.Certificate[0]) + if err != nil { + t.Fatalf("failed to parse initial certificate: %v", err) + } + + // Simulate an invalid certificate update (mismatched cert and key) + certPEM2, _, err := createTestCertificate(2) + if err != nil { + t.Fatalf("failed to create test certificate 2: %v", err) + } + _, keyPEM3, err := createTestCertificate(3) + if err != nil { + t.Fatalf("failed to create test certificate 3: %v", err) + } + + // Update with mismatched cert and key (should fail to load) + updateK8sStyleCerts(t, baseDir, certPEM2, keyPEM3) + + // Wait a bit to allow the reload attempt + time.Sleep(2 * time.Second) + + // Verify that the old certificate is still loaded (reload should have failed) + currentCert := reloader.Get() + x509CurrentCert, err := x509.ParseCertificate(currentCert.Certificate[0]) + if err != nil { + t.Fatalf("failed to parse current certificate: %v", err) + } + + if x509CurrentCert.SerialNumber.Int64() != x509InitialCert.SerialNumber.Int64() { + t.Errorf("expected certificate to remain unchanged on reload error, but serial number changed from %d to %d", + x509InitialCert.SerialNumber.Int64(), x509CurrentCert.SerialNumber.Int64()) + } +} + +// createTestCertificate generates a test TLS certificate and key pair with a given serial number. +// The serial number is used to uniquely identify different certificates in tests. +func createTestCertificate(serialNum int64) (certPEM, keyPEM []byte, err error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + notBefore := time.Now() + notAfter := notBefore.Add(24 * time.Hour) + + template := x509.Certificate{ + SerialNumber: big.NewInt(serialNum), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + CommonName: "test-cert", + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + + return certPEM, keyPEM, nil +} + +// setupK8sStyleCertDir creates a directory structure that mimics Kubernetes secret volume mounts. +// It creates: +// - A timestamped data directory (..2025_01_15_12_00_00.123456789/) +// - A ..data symlink pointing to the timestamped directory +// - tls.crt and tls.key symlinks pointing to ..data/tls.crt and ..data/tls.key +func setupK8sStyleCertDir(t *testing.T, certPEM, keyPEM []byte) string { + t.Helper() + + baseDir := t.TempDir() + + // Create initial timestamped directory + timestamp := time.Now().Format("..2006_01_02_15_04_05.000000000") + dataDir := filepath.Join(baseDir, timestamp) + if err := os.MkdirAll(dataDir, 0755); err != nil { + t.Fatalf("failed to create data directory: %v", err) + } + + // Write certificates to the timestamped directory + certPath := filepath.Join(dataDir, "tls.crt") + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + t.Fatalf("failed to write certificate: %v", err) + } + + keyPath := filepath.Join(dataDir, "tls.key") + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + t.Fatalf("failed to write key: %v", err) + } + + // Create ..data symlink pointing to the timestamped directory + dotDataLink := filepath.Join(baseDir, "..data") + if err := os.Symlink(timestamp, dotDataLink); err != nil { + t.Fatalf("failed to create ..data symlink: %v", err) + } + + // Create tls.crt and tls.key symlinks pointing through ..data + tlsCrtLink := filepath.Join(baseDir, "tls.crt") + if err := os.Symlink(filepath.Join("..data", "tls.crt"), tlsCrtLink); err != nil { + t.Fatalf("failed to create tls.crt symlink: %v", err) + } + + tlsKeyLink := filepath.Join(baseDir, "tls.key") + if err := os.Symlink(filepath.Join("..data", "tls.key"), tlsKeyLink); err != nil { + t.Fatalf("failed to create tls.key symlink: %v", err) + } + + return baseDir +} + +// updateK8sStyleCerts simulates a Kubernetes secret update by: +// 1. Creating a new timestamped directory with new certificates +// 2. Atomically updating the ..data symlink to point to the new directory +// This mimics how Kubernetes updates secret volumes. +func updateK8sStyleCerts(t *testing.T, baseDir string, newCertPEM, newKeyPEM []byte) { + t.Helper() + + // Create new timestamped directory + newTimestamp := time.Now().Format("..2006_01_02_15_04_05.000000000") + newDataDir := filepath.Join(baseDir, newTimestamp) + if err := os.MkdirAll(newDataDir, 0755); err != nil { + t.Fatalf("failed to create new data directory: %v", err) + } + + // Write new certificates + certPath := filepath.Join(newDataDir, "tls.crt") + if err := os.WriteFile(certPath, newCertPEM, 0644); err != nil { + t.Fatalf("failed to write new certificate: %v", err) + } + + keyPath := filepath.Join(newDataDir, "tls.key") + if err := os.WriteFile(keyPath, newKeyPEM, 0600); err != nil { + t.Fatalf("failed to write new key: %v", err) + } + + // Atomically update ..data symlink + // In Kubernetes, this is done atomically by creating a new symlink and renaming it + dotDataLink := filepath.Join(baseDir, "..data") + dotDataTmp := filepath.Join(baseDir, "..data_tmp") + + // Create temporary symlink + if err := os.Symlink(newTimestamp, dotDataTmp); err != nil { + t.Fatalf("failed to create temporary ..data symlink: %v", err) + } + + // Atomically replace the old symlink + if err := os.Rename(dotDataTmp, dotDataLink); err != nil { + t.Fatalf("failed to update ..data symlink: %v", err) + } +} diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index c3037175e..8075c7c64 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -54,6 +54,7 @@ type ExtProcServerRunner struct { SecureServing bool HealthChecking bool CertPath string + EnableCertReload bool RefreshPrometheusMetricsInterval time.Duration MetricsStalenessThreshold time.Duration Director *requestcontrol.Director @@ -82,6 +83,7 @@ const ( DefaultLoraInfoMetric = "vllm:lora_requests_info" // default for --lora-info-metric DefaultCacheInfoMetric = "vllm:cache_config_info" // default for --cache-info-metric DefaultCertPath = "" // default for --cert-path + DefaultCertReload = false // default for --enable-cert-reload DefaultConfigFile = "" // default for --config-file DefaultConfigText = "" // default for --config-text DefaultPoolGroup = "inference.networking.k8s.io" // default for --pool-group @@ -162,9 +164,22 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { return fmt.Errorf("failed to create self signed certificate - %w", err) } - creds := credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{cert}, - }) + var creds credentials.TransportCredentials + if r.CertPath != "" && r.EnableCertReload { + reloader, err := common.NewCertReloader(ctx, r.CertPath, &cert) + if err != nil { + return fmt.Errorf("failed to create cert reloader: %w", err) + } + creds = credentials.NewTLS(&tls.Config{ + GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return reloader.Get(), nil + }, + }) + } else { + creds = credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } // Init the server. srv = grpc.NewServer(grpc.Creds(creds)) } else { diff --git a/pkg/epp/util/logging/logger.go b/pkg/epp/util/logging/logger.go index 5e6ed88da..284e1bd9d 100644 --- a/pkg/epp/util/logging/logger.go +++ b/pkg/epp/util/logging/logger.go @@ -21,16 +21,21 @@ import ( "github.com/go-logr/logr" uberzap "go.uber.org/zap" + "go.uber.org/zap/zapcore" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" ) // NewTestLogger creates a new Zap logger using the dev mode. func NewTestLogger() logr.Logger { - return zap.New(zap.UseDevMode(true), zap.RawZapOpts(uberzap.AddCaller())) + return zap.New( + zap.UseDevMode(true), + zap.Level(uberzap.NewAtomicLevelAt(zapcore.Level(-1*TRACE))), + zap.RawZapOpts(uberzap.AddCaller()), + ) } // NewTestLoggerIntoContext creates a new Zap logger using the dev mode and inserts it into the given context. func NewTestLoggerIntoContext(ctx context.Context) context.Context { - return log.IntoContext(ctx, zap.New(zap.UseDevMode(true), zap.RawZapOpts(uberzap.AddCaller()))) + return log.IntoContext(ctx, NewTestLogger()) }