From 95e4d632d01bc9d195a223d28c2d309436767b5c Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 7 Nov 2025 16:59:15 -0700 Subject: [PATCH] Filter CMAP/SDAM events during pool initialization for awaitMinPoolSizeMS --- internal/integration/unified/client_entity.go | 129 +++++++-- .../integration/unified/client_entity_test.go | 253 ++++++++++++++++++ internal/integration/unified/entity.go | 9 +- .../integration/unified/event_verification.go | 18 +- 4 files changed, 367 insertions(+), 42 deletions(-) create mode 100644 internal/integration/unified/client_entity_test.go diff --git a/internal/integration/unified/client_entity.go b/internal/integration/unified/client_entity.go index bc981793df..c21cbfffd2 100644 --- a/internal/integration/unified/client_entity.go +++ b/internal/integration/unified/client_entity.go @@ -38,6 +38,50 @@ var securitySensitiveCommands = []string{ "createUser", "updateUser", "copydbgetnonce", "copydbsaslstart", "copydb", } +// eventSequencer allows for sequence-based event filtering for +// awaitMinPoolSizeMS support. +// +// Per the unified test format spec, when awaitMinPoolSizeMS is specified, any +// CMAP and SDAM events that occur during connection pool initialization +// (before minPoolSize is reached) must be ignored. We track this by +// assigning a monotonically increasing sequence number to each event as it's +// recorded. After pool initialization completes, we set eventCutoffSeq to the +// current sequence number. Event accessors for CMAP and SDAM types then +// filter out any events with sequence <= eventCutoffSeq. +type eventSequencer struct { + counter atomic.Int64 + cutoff int64 + + // pool events are heterogeneous, so we track their sequence separately + poolSeq []int64 + seqByEventType map[monitoringEventType][]int64 +} + +// setCutoff marks the current sequence as the filtering cutoff point. +func (es *eventSequencer) setCutoff() { + es.cutoff = es.counter.Load() +} + +// recordEvent stores the sequence number for a given event type. +func (es *eventSequencer) recordEvent(eventType monitoringEventType) { + next := es.counter.Add(1) + es.seqByEventType[eventType] = append(es.seqByEventType[eventType], next) +} + +func (es *eventSequencer) recordPooledEvent() { + next := es.counter.Add(1) + es.poolSeq = append(es.poolSeq, next) +} + +// shouldFilter returns true if the event at the given index should be filtered. +func (es *eventSequencer) shouldFilter(eventType monitoringEventType, index int) bool { + if es.cutoff == 0 { + return false + } + + return es.seqByEventType[eventType][index] <= es.cutoff +} + // clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test // execution. type clientEntity struct { @@ -72,30 +116,8 @@ type clientEntity struct { entityMap *EntityMap - logQueue chan orderedLogMessage -} - -// awaitMinimumPoolSize waits for the client's connection pool to reach the -// specified minimum size. This is a best effort operation that times out after -// some predefined amount of time to avoid blocking tests indefinitely. -func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64) error { - // Don't spend longer than 500ms awaiting minPoolSize. 
- awaitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-awaitCtx.Done(): - return fmt.Errorf("timed out waiting for client to reach minPoolSize") - case <-ticker.C: - if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize { - return nil - } - } - } + logQueue chan orderedLogMessage + eventSequencer eventSequencer } func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOptions) (*clientEntity, error) { @@ -118,6 +140,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp serverDescriptionChangedEventsCount: make(map[serverDescriptionChangedEventInfo]int32), entityMap: em, observeSensitiveCommands: entityOptions.ObserveSensitiveCommands, + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, } entity.setRecordEvents(true) @@ -226,8 +251,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp return nil, fmt.Errorf("error creating mongo.Client: %w", err) } - if entityOptions.AwaitMinPoolSize && clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 { - if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize); err != nil { + if entityOptions.AwaitMinPoolSizeMS != nil && *entityOptions.AwaitMinPoolSizeMS > 0 && + clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 { + if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize, *entityOptions.AwaitMinPoolSizeMS); err != nil { return nil, err } } @@ -326,8 +352,21 @@ func (c *clientEntity) failedEvents() []*event.CommandFailedEvent { return events } -func (c *clientEntity) poolEvents() []*event.PoolEvent { - return c.pooled +// filterEventsBySeq filters events by sequence number using the provided +// sequence slice. See comments on eventSequencer for more details. +func filterEventsBySeq[T any](c *clientEntity, events []T, seqSlice []int64) []T { + if c.eventSequencer.cutoff == 0 { + return events + } + + var filtered []T + for i, evt := range events { + if seqSlice[i] > c.eventSequencer.cutoff { + filtered = append(filtered, evt) + } + } + + return filtered } func (c *clientEntity) numberConnectionsCheckedOut() int32 { @@ -517,6 +556,7 @@ func (c *clientEntity) processPoolEvent(evt *event.PoolEvent) { eventType := monitoringEventTypeFromPoolEvent(evt) if _, ok := c.observedEvents[eventType]; ok { c.pooled = append(c.pooled, evt) + c.eventSequencer.recordPooledEvent() } c.addEventsCount(eventType) @@ -539,6 +579,7 @@ func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDes if _, ok := c.observedEvents[serverDescriptionChangedEvent]; ok { c.serverDescriptionChanged = append(c.serverDescriptionChanged, evt) + c.eventSequencer.recordEvent(serverDescriptionChangedEvent) } // Record object-specific unified spec test data on an event. 
@@ -558,6 +599,7 @@ func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartb if _, ok := c.observedEvents[serverHeartbeatFailedEvent]; ok { c.serverHeartbeatFailedEvent = append(c.serverHeartbeatFailedEvent, evt) + c.eventSequencer.recordEvent(serverHeartbeatFailedEvent) } c.addEventsCount(serverHeartbeatFailedEvent) @@ -573,6 +615,7 @@ func (c *clientEntity) processServerHeartbeatStartedEvent(evt *event.ServerHeart if _, ok := c.observedEvents[serverHeartbeatStartedEvent]; ok { c.serverHeartbeatStartedEvent = append(c.serverHeartbeatStartedEvent, evt) + c.eventSequencer.recordEvent(serverHeartbeatStartedEvent) } c.addEventsCount(serverHeartbeatStartedEvent) @@ -588,6 +631,7 @@ func (c *clientEntity) processServerHeartbeatSucceededEvent(evt *event.ServerHea if _, ok := c.observedEvents[serverHeartbeatSucceededEvent]; ok { c.serverHeartbeatSucceeded = append(c.serverHeartbeatSucceeded, evt) + c.eventSequencer.recordEvent(serverHeartbeatSucceededEvent) } c.addEventsCount(serverHeartbeatSucceededEvent) @@ -603,6 +647,7 @@ func (c *clientEntity) processTopologyDescriptionChangedEvent(evt *event.Topolog if _, ok := c.observedEvents[topologyDescriptionChangedEvent]; ok { c.topologyDescriptionChanged = append(c.topologyDescriptionChanged, evt) + c.eventSequencer.recordEvent(topologyDescriptionChangedEvent) } c.addEventsCount(topologyDescriptionChangedEvent) @@ -618,6 +663,7 @@ func (c *clientEntity) processTopologyOpeningEvent(evt *event.TopologyOpeningEve if _, ok := c.observedEvents[topologyOpeningEvent]; ok { c.topologyOpening = append(c.topologyOpening, evt) + c.eventSequencer.recordEvent(topologyOpeningEvent) } c.addEventsCount(topologyOpeningEvent) @@ -633,6 +679,7 @@ func (c *clientEntity) processTopologyClosedEvent(evt *event.TopologyClosedEvent if _, ok := c.observedEvents[topologyClosedEvent]; ok { c.topologyClosed = append(c.topologyClosed, evt) + c.eventSequencer.recordEvent(topologyClosedEvent) } c.addEventsCount(topologyClosedEvent) @@ -724,3 +771,29 @@ func evaluateUseMultipleMongoses(clientOpts *options.ClientOptions, useMultipleM } return nil } + +// awaitMinimumPoolSize waits for the client's connection pool to reach the +// specified minimum size, then clears all CMAP and SDAM events that occurred +// during pool initialization. +func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64, timeoutMS int) error { + awaitCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutMS)*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-awaitCtx.Done(): + return fmt.Errorf("timed out waiting for client to reach minPoolSize") + case <-ticker.C: + if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize { + // Clear all CMAP and SDAM events that occurred during pool + // initialization. + entity.eventSequencer.setCutoff() + + return nil + } + } + } +} diff --git a/internal/integration/unified/client_entity_test.go b/internal/integration/unified/client_entity_test.go new file mode 100644 index 0000000000..9cb08b0f76 --- /dev/null +++ b/internal/integration/unified/client_entity_test.go @@ -0,0 +1,253 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package unified + +import ( + "sync/atomic" + "testing" + + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +// Helper functions to condense event recording in tests +func recordPoolEvent(c *clientEntity) { + c.pooled = append(c.pooled, &event.PoolEvent{}) + c.eventSequencer.recordPooledEvent() +} + +func recordServerDescChanged(c *clientEntity) { + c.serverDescriptionChanged = append(c.serverDescriptionChanged, &event.ServerDescriptionChangedEvent{}) + c.eventSequencer.recordEvent(serverDescriptionChangedEvent) +} + +func recordTopologyOpening(c *clientEntity) { + c.topologyOpening = append(c.topologyOpening, &event.TopologyOpeningEvent{}) + c.eventSequencer.recordEvent(topologyOpeningEvent) +} + +func recordHeartbeatSucceeded(c *clientEntity) { + c.serverHeartbeatSucceeded = append(c.serverHeartbeatSucceeded, &event.ServerHeartbeatSucceededEvent{}) + c.eventSequencer.recordEvent(serverHeartbeatSucceededEvent) +} + +func Test_eventSequencer(t *testing.T) { + tests := []struct { + name string + setupEvents func(*clientEntity) + cutoffAfter int // Set cutoff after this many events (0 = no cutoff) + expectedPooled int + expectedSDAM map[monitoringEventType]int + }{ + { + name: "no cutoff filters nothing", + cutoffAfter: 0, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + recordPoolEvent(c) + recordServerDescChanged(c) + recordServerDescChanged(c) + }, + expectedPooled: 3, + expectedSDAM: map[monitoringEventType]int{ + serverDescriptionChangedEvent: 2, + }, + }, + { + name: "cutoff after 2 pool events filters first 2", + cutoffAfter: 2, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + // Cutoff will be set here (after event 2) + recordPoolEvent(c) + recordPoolEvent(c) + recordPoolEvent(c) + }, + expectedPooled: 3, // Events 3, 4, 5 + expectedSDAM: map[monitoringEventType]int{}, + }, + { + name: "cutoff filters mixed pool and SDAM events", + cutoffAfter: 4, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordServerDescChanged(c) + recordPoolEvent(c) + recordTopologyOpening(c) + // Cutoff will be set here (after event 4) + recordPoolEvent(c) + recordServerDescChanged(c) + recordTopologyOpening(c) + }, + expectedPooled: 1, + expectedSDAM: map[monitoringEventType]int{ + serverDescriptionChangedEvent: 1, + topologyOpeningEvent: 1, + }, + }, + { + name: "cutoff at beginning filters nothing", + cutoffAfter: 0, + setupEvents: func(c *clientEntity) { + // Cutoff will be set immediately (before any events) + recordPoolEvent(c) + recordHeartbeatSucceeded(c) + }, + expectedPooled: 1, + expectedSDAM: map[monitoringEventType]int{ + serverHeartbeatSucceededEvent: 1, + }, + }, + { + name: "cutoff after all events filters everything", + cutoffAfter: 3, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + recordServerDescChanged(c) + // Cutoff will be set here (after all 3 events) + }, + expectedPooled: 0, + expectedSDAM: map[monitoringEventType]int{ + serverDescriptionChangedEvent: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal clientEntity + client := &clientEntity{ + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, + } + + // Setup events + tt.setupEvents(client) + + // Set cutoff if specified + if tt.cutoffAfter > 0 { + // Manually set cutoff to the 
specified event sequence + client.eventSequencer.cutoff = int64(tt.cutoffAfter) + } + + // Test pool event filtering + filteredPool := filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq) + assert.Equal(t, tt.expectedPooled, len(filteredPool), "pool events count mismatch") + + // Test SDAM event filtering + for eventType, expectedCount := range tt.expectedSDAM { + var actualCount int + seqs := client.eventSequencer.seqByEventType[eventType] + + switch eventType { + case serverDescriptionChangedEvent: + actualCount = len(filterEventsBySeq(client, client.serverDescriptionChanged, seqs)) + case serverHeartbeatSucceededEvent: + actualCount = len(filterEventsBySeq(client, client.serverHeartbeatSucceeded, seqs)) + case topologyOpeningEvent: + actualCount = len(filterEventsBySeq(client, client.topologyOpening, seqs)) + } + + assert.Equal(t, expectedCount, actualCount, "%s count mismatch", eventType) + } + }) + } +} + +func Test_eventSequencer_setCutoff(t *testing.T) { + client := &clientEntity{ + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, + } + + // Record some events + recordPoolEvent(client) + recordPoolEvent(client) + + // Verify counter is at 2 + assert.Equal(t, int64(2), client.eventSequencer.counter.Load(), "counter should be 2") + + // Set cutoff + client.eventSequencer.setCutoff() + + // Verify cutoff matches counter + assert.Equal(t, int64(2), client.eventSequencer.cutoff, "cutoff should be 2") + + // Record more events + recordPoolEvent(client) + + // Verify counter incremented but cutoff didn't + assert.Equal(t, int64(3), client.eventSequencer.counter.Load(), "counter should be 3") + assert.Equal(t, int64(2), client.eventSequencer.cutoff, "cutoff should still be 2") +} + +func Test_eventSequencer_shouldFilter(t *testing.T) { + es := &eventSequencer{ + seqByEventType: map[monitoringEventType][]int64{ + serverDescriptionChangedEvent: {1, 2, 3, 4, 5}, + }, + } + es.counter = atomic.Int64{} + es.counter.Store(5) + + tests := []struct { + name string + cutoff int64 + eventType monitoringEventType + index int + expected bool + }{ + { + name: "no cutoff", + cutoff: 0, + eventType: serverDescriptionChangedEvent, + index: 0, + expected: false, + }, + { + name: "before cutoff", + cutoff: 3, + eventType: serverDescriptionChangedEvent, + index: 0, + expected: true, // seq=1 <= 3 + }, + { + name: "at cutoff", + cutoff: 3, + eventType: serverDescriptionChangedEvent, + index: 2, + expected: true, // seq=3 <= 3 + }, + { + name: "after cutoff", + cutoff: 3, + eventType: serverDescriptionChangedEvent, + index: 3, + expected: false, // seq=4 > 3 + }, + { + name: "last event", + cutoff: 3, + eventType: serverDescriptionChangedEvent, + index: 4, + expected: false, // seq=5 > 3 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + es.cutoff = tt.cutoff + result := es.shouldFilter(tt.eventType, tt.index) + assert.Equal(t, tt.expected, result, "shouldFilter result mismatch") + }) + } +} diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index b1b827a124..8222233290 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -83,11 +83,10 @@ type entityOptions struct { ClientEncryptionOpts *clientEncryptionOpts `bson:"clientEncryptionOpts"` - // If true, the unified spec runner must wait for the connection pool to be - // populated for all servers according to the minPoolSize option. 
If false, - // not specified, or if minPoolSize equals 0, there is no need to wait for any - // specific pool state. - AwaitMinPoolSize bool `bson:"awaitMinPoolSize"` + // Maximum duration (in milliseconds) that the test runner MUST wait for each + // connection pool to be populated with minPoolSize. Any CMAP and SDAM events + // that occur before the pool is populated will be ignored. + AwaitMinPoolSizeMS *int `bson:"awaitMinPoolSizeMS"` } func (eo *entityOptions) setHeartbeatFrequencyMS(freq time.Duration) { diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go index 0521f0653e..b3da009081 100644 --- a/internal/integration/unified/event_verification.go +++ b/internal/integration/unified/event_verification.go @@ -312,7 +312,7 @@ func verifyCommandEvents(ctx context.Context, client *clientEntity, expectedEven } func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) error { - pooled := client.poolEvents() + pooled := filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq) if len(expectedEvents.CMAPEvents) == 0 && len(pooled) != 0 { return fmt.Errorf("expected no cmap events to be sent but got %s", stringifyEventsForClient(client)) } @@ -443,7 +443,7 @@ func stringifyEventsForClient(client *clientEntity) string { } str.WriteString("\nPool Events\n\n") - for _, evt := range client.poolEvents() { + for _, evt := range filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq) { str.WriteString(fmt.Sprintf("[%s] Event Type: %q\n", evt.Address, evt.Type)) } @@ -522,13 +522,13 @@ func getNextTopologyClosedEvent( func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error { var ( - changed = client.serverDescriptionChanged - started = client.serverHeartbeatStartedEvent - succeeded = client.serverHeartbeatSucceeded - failed = client.serverHeartbeatFailedEvent - tchanged = client.topologyDescriptionChanged - topening = client.topologyOpening - tclosed = client.topologyClosed + changed = filterEventsBySeq(client, client.serverDescriptionChanged, client.eventSequencer.seqByEventType[serverDescriptionChangedEvent]) + started = filterEventsBySeq(client, client.serverHeartbeatStartedEvent, client.eventSequencer.seqByEventType[serverHeartbeatStartedEvent]) + succeeded = filterEventsBySeq(client, client.serverHeartbeatSucceeded, client.eventSequencer.seqByEventType[serverHeartbeatSucceededEvent]) + failed = filterEventsBySeq(client, client.serverHeartbeatFailedEvent, client.eventSequencer.seqByEventType[serverHeartbeatFailedEvent]) + tchanged = filterEventsBySeq(client, client.topologyDescriptionChanged, client.eventSequencer.seqByEventType[topologyDescriptionChangedEvent]) + topening = filterEventsBySeq(client, client.topologyOpening, client.eventSequencer.seqByEventType[topologyOpeningEvent]) + tclosed = filterEventsBySeq(client, client.topologyClosed, client.eventSequencer.seqByEventType[topologyClosedEvent]) ) vol := func() int {
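
For illustration, here is a minimal standalone sketch of how the sequence-based cutoff introduced above is intended to behave: events recorded before setCutoff are dropped by the filter, and events recorded afterwards are kept. The types and event names in the sketch (the local eventSequencer copy, filterBySeq, and the string event labels) are simplified stand-ins for this patch's mechanism, not part of the driver itself.

package main

import (
	"fmt"
	"sync/atomic"
)

// eventSequencer mirrors the patch's idea: every recorded event receives a
// monotonically increasing sequence number, and setCutoff marks the point
// after which events count as "post pool initialization".
type eventSequencer struct {
	counter atomic.Int64
	cutoff  int64
	seq     []int64 // one sequence number per recorded event, in record order
}

func (es *eventSequencer) record()    { es.seq = append(es.seq, es.counter.Add(1)) }
func (es *eventSequencer) setCutoff() { es.cutoff = es.counter.Load() }

// filterBySeq keeps only events whose sequence number is greater than the
// cutoff, matching the semantics of filterEventsBySeq in the patch.
func filterBySeq[T any](events []T, seq []int64, cutoff int64) []T {
	if cutoff == 0 {
		return events
	}
	var kept []T
	for i, evt := range events {
		if seq[i] > cutoff {
			kept = append(kept, evt)
		}
	}
	return kept
}

func main() {
	var es eventSequencer
	var events []string

	// Events observed while the pool is still being populated.
	for _, name := range []string{"ConnectionCreated", "ConnectionReady"} {
		events = append(events, name)
		es.record()
	}

	// minPoolSize reached: everything recorded so far must be ignored.
	es.setCutoff()

	// Events observed by the test proper.
	for _, name := range []string{"ConnectionCheckedOut", "ConnectionCheckedIn"} {
		events = append(events, name)
		es.record()
	}

	// Prints [ConnectionCheckedOut ConnectionCheckedIn].
	fmt.Println(filterBySeq(events, es.seq, es.cutoff))
}

Filtering at read time by sequence number, rather than clearing the recorded slices once the pool is populated, keeps the raw event history intact for debugging while still honoring the spec requirement that CMAP and SDAM events observed before minPoolSize is reached be ignored.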