@@ -270,21 +270,30 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270270 return
271271 }
272272
273- for idx , expectation := range * expectations {
274- var err error
273+ startedEvents := make ([]* cmdStartedEvt , 0 , len (* expectations ))
274+ succeededEvents := make ([]* cmdSucceededEvt , 0 , len (* expectations ))
275+ failedEvents := make ([]* cmdFailedEvt , 0 , len (* expectations ))
275276
277+ for _ , expectation := range * expectations {
276278 if expectation .CommandStartedEvent != nil {
277- err = compareStartedEvent ( mt , expectation , id0 , id1 )
279+ startedEvents = append ( startedEvents , expectation . CommandStartedEvent )
278280 }
279281 if expectation .CommandSucceededEvent != nil {
280- err = compareSucceededEvent ( mt , expectation )
282+ succeededEvents = append ( succeededEvents , expectation . CommandSucceededEvent )
281283 }
282284 if expectation .CommandFailedEvent != nil {
283- err = compareFailedEvent ( mt , expectation )
285+ failedEvents = append ( failedEvents , expectation . CommandFailedEvent )
284286 }
285-
286- assert .Nil (mt , err , "expectation comparison error at index %v: %s" , idx , err )
287287 }
288+
289+ var err error
290+ err = compareStartedEvents (mt , startedEvents , id0 , id1 )
291+ assert .Nil (mt , err , "expectation comparison %s" , err )
292+ err = compareSucceededEvents (mt , succeededEvents )
293+ assert .Nil (mt , err , "expectation comparison %s" , err )
294+ err = compareFailedEvents (mt , failedEvents )
295+ assert .Nil (mt , err , "expectation comparison %s" , err )
296+
288297}
289298
290299// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +307,104 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298307 return fmt .Errorf ("%s\n Expected %s\n Got: %s" , msg , string (expectedJSON ), string (actualJSON ))
299308}
300309
301- func compareStartedEvent (mt * mtest.T , expectation * expectation , id0 , id1 bson.Raw ) error {
310+ func compareStartedEvents (mt * mtest.T , expectations [] * cmdStartedEvt , id0 , id1 bson.Raw ) error {
302311 mt .Helper ()
303312
304- expected := expectation .CommandStartedEvent
305-
306- if len (expected .Extra ) > 0 {
307- return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
308- }
309-
310- evt := mt .GetStartedEvent ()
311- if evt == nil {
312- return errors .New ("expected CommandStartedEvent, got nil" )
313- }
314-
315- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
316- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
317- }
318- if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
319- return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
320- }
321-
322- eElems , err := expected .Command .Elements ()
323- if err != nil {
324- return fmt .Errorf ("error getting expected command elements: %s" , err )
313+ expectedCmds := make (map [string ]bool )
314+ for _ , expected := range expectations {
315+ expectedCmds [expected .CommandName ] = true
325316 }
326317
327- for _ , elem := range eElems {
328- key := elem .Key ()
329- val := elem .Value ()
330-
331- actualVal , err := evt .Command .LookupErr (key )
318+ compare := func (expected * cmdStartedEvt ) error {
319+ if len (expected .Extra ) > 0 {
320+ return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
321+ }
332322
333- // Keys that may be nil
334- if val .Type == bson .TypeNull {
335- // Expected value is BSON null. Expect the actual field to be omitted.
336- if errors .Is (err , bsoncore .ErrElementNotFound ) {
337- continue
323+ var evt * event.CommandStartedEvent
324+ // skip events not in expectations
325+ for {
326+ evt = mt .GetStartedEvent ()
327+ if evt == nil {
328+ return fmt .Errorf ("expected CommandStartedEvent %s, got nil" , expected .CommandName )
338329 }
339- if err != nil {
340- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
330+ if expected .CommandName == "" {
331+ break
332+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
333+ break
341334 }
342- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
343335 }
344- assert .Nil (mt , err , "expected command to contain key %q" , key )
345336
346- if key == "batchSize" {
347- // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348- // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349- // versions do not support the limit option, but not for 3.2+. We've already validated that the command
350- // contains a batchSize field above and we can skip the actual value comparison below.
351- continue
337+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
338+ return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
339+ }
340+ if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
341+ return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
352342 }
353343
354- switch key {
355- case "lsid" :
356- sessName := val .StringValue ()
357- var expectedID bson.Raw
358- actualID := actualVal .Document ()
344+ eElems , err := expected .Command .Elements ()
345+ if err != nil {
346+ return fmt .Errorf ("error getting expected command elements: %s" , err )
347+ }
359348
360- switch sessName {
361- case "session0" :
362- expectedID = id0
363- case "session1" :
364- expectedID = id1
365- default :
366- return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
349+ for _ , elem := range eElems {
350+ key := elem .Key ()
351+ val := elem .Value ()
352+
353+ actualVal , err := evt .Command .LookupErr (key )
354+
355+ // Keys that may be nil
356+ if val .Type == bson .TypeNull {
357+ // Expected value is BSON null. Expect the actual field to be omitted.
358+ if errors .Is (err , bsoncore .ErrElementNotFound ) {
359+ continue
360+ }
361+ if err != nil {
362+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
363+ }
364+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
367365 }
366+ assert .Nil (mt , err , "expected command to contain key %q" , key )
368367
369- if ! bytes .Equal (expectedID , actualID ) {
370- return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
371- actualID )
368+ if key == "batchSize" {
369+ // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
370+ // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
371+ // versions do not support the limit option, but not for 3.2+. We've already validated that the command
372+ // contains a batchSize field above and we can skip the actual value comparison below.
373+ continue
372374 }
373- default :
374- if err := compareValues (mt , key , val , actualVal ); err != nil {
375- return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
375+
376+ switch key {
377+ case "lsid" :
378+ sessName := val .StringValue ()
379+ var expectedID bson.Raw
380+ actualID := actualVal .Document ()
381+
382+ switch sessName {
383+ case "session0" :
384+ expectedID = id0
385+ case "session1" :
386+ expectedID = id1
387+ default :
388+ return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
389+ }
390+
391+ if ! bytes .Equal (expectedID , actualID ) {
392+ return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
393+ actualID )
394+ }
395+ default :
396+ if err := compareValues (mt , key , val , actualVal ); err != nil {
397+ return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
398+ }
376399 }
377400 }
401+ return nil
402+ }
403+ for idx , expected := range expectations {
404+ err := compare (expected )
405+ if err != nil {
406+ return fmt .Errorf ("error at index %d: %s" , idx , err )
407+ }
378408 }
379409 return nil
380410}
@@ -416,60 +446,106 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416446 return nil
417447}
418448
419- func compareSucceededEvent (mt * mtest.T , expectation * expectation ) error {
449+ func compareSucceededEvents (mt * mtest.T , expectations [] * cmdSucceededEvt ) error {
420450 mt .Helper ()
421451
422- expected := expectation .CommandSucceededEvent
423- if len (expected .Extra ) > 0 {
424- return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
425- }
426- evt := mt .GetSucceededEvent ()
427- if evt == nil {
428- return errors .New ("expected CommandSucceededEvent, got nil" )
452+ expectedCmds := make (map [string ]bool )
453+ for _ , expected := range expectations {
454+ expectedCmds [expected .CommandName ] = true
429455 }
430456
431- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
432- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
433- }
457+ compare := func (expected * cmdSucceededEvt ) error {
458+ if len (expected .Extra ) > 0 {
459+ return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
460+ }
434461
435- eElems , err := expected .Reply .Elements ()
436- if err != nil {
437- return fmt .Errorf ("error getting expected reply elements: %s" , err )
438- }
462+ var evt * event.CommandSucceededEvent
463+ // skip events not in expectations
464+ for {
465+ evt = mt .GetSucceededEvent ()
466+ if evt == nil {
467+ return fmt .Errorf ("expected CommandSucceededEvent %s, got nil" , expected .CommandName )
468+ }
469+ if expected .CommandName == "" {
470+ break
471+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
472+ break
473+ }
474+ }
439475
440- for _ , elem := range eElems {
441- key := elem .Key ()
442- val := elem .Value ()
443- actualVal := evt .Reply .Lookup (key )
476+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
477+ return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
478+ }
444479
445- switch key {
446- case "writeErrors" :
447- if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
448- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
449- }
450- default :
451- if err := compareValues (mt , key , val , actualVal ); err != nil {
452- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
480+ eElems , err := expected .Reply .Elements ()
481+ if err != nil {
482+ return fmt .Errorf ("error getting expected reply elements: %s" , err )
483+ }
484+
485+ for _ , elem := range eElems {
486+ key := elem .Key ()
487+ val := elem .Value ()
488+ actualVal := evt .Reply .Lookup (key )
489+
490+ switch key {
491+ case "writeErrors" :
492+ if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
493+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
494+ }
495+ default :
496+ if err := compareValues (mt , key , val , actualVal ); err != nil {
497+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
498+ }
453499 }
454500 }
501+ return nil
502+ }
503+ for idx , expected := range expectations {
504+ err := compare (expected )
505+ if err != nil {
506+ return fmt .Errorf ("error at index %d: %s" , idx , err )
507+ }
455508 }
456509 return nil
457510}
458511
459- func compareFailedEvent (mt * mtest.T , expectation * expectation ) error {
512+ func compareFailedEvents (mt * mtest.T , expectations [] * cmdFailedEvt ) error {
460513 mt .Helper ()
461514
462- expected := expectation .CommandFailedEvent
463- if len (expected .Extra ) > 0 {
464- return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
465- }
466- evt := mt .GetFailedEvent ()
467- if evt == nil {
468- return errors .New ("expected CommandFailedEvent, got nil" )
515+ expectedCmds := make (map [string ]bool )
516+ for _ , expected := range expectations {
517+ expectedCmds [expected .CommandName ] = true
469518 }
470519
471- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
472- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
520+ compare := func (expected * cmdFailedEvt ) error {
521+ if len (expected .Extra ) > 0 {
522+ return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
523+ }
524+
525+ var evt * event.CommandFailedEvent
526+ // skip events not in expectations
527+ for {
528+ evt = mt .GetFailedEvent ()
529+ if evt == nil {
530+ return fmt .Errorf ("expected CommandFailedEvent %s, got nil" , expected .CommandName )
531+ }
532+ if expected .CommandName == "" {
533+ break
534+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
535+ break
536+ }
537+ }
538+
539+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
540+ return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
541+ }
542+ return nil
543+ }
544+ for idx , expected := range expectations {
545+ err := compare (expected )
546+ if err != nil {
547+ return fmt .Errorf ("error at index %d: %s" , idx , err )
548+ }
473549 }
474550 return nil
475551}
0 commit comments