1717package org .mongodb .scala .internal
1818
1919import java .util .concurrent .ConcurrentLinkedQueue
20+ import java .util .concurrent .atomic .AtomicLong
2021
2122import org .mongodb .scala .{ Observable , Observer , Subscription }
2223
23- private [scala] case class ZipObservable [T , U ](
24- observable1 : Observable [T ],
25- observable2 : Observable [U ]
26- ) extends Observable [(T , U )] {
24+ private [scala] case class ZipObservable [L , R ](
25+ leftObservable : Observable [L ],
26+ rightObservable : Observable [R ]
27+ ) extends Observable [(L , R )] {
2728
28- def subscribe (observer : Observer [_ >: (T , U )]): Unit = {
29+ def subscribe (observer : Observer [_ >: (L , R )]): Unit = {
2930 val helper = SubscriptionHelper (observer)
30- observable1 .subscribe(SubscriptionCheckingObserver (helper.createFirstObserver ))
31- observable2 .subscribe(SubscriptionCheckingObserver (helper.createSecondObserver ))
31+ leftObservable .subscribe(SubscriptionCheckingObserver (helper.createLeftObserver ))
32+ rightObservable .subscribe(SubscriptionCheckingObserver (helper.createRightObserver ))
3233 }
3334
34- case class SubscriptionHelper (observer : Observer [_ >: (T , U )]) {
35- private val thisQueue : ConcurrentLinkedQueue [(Long , T )] = new ConcurrentLinkedQueue [(Long , T )]()
36- private val thatQueue : ConcurrentLinkedQueue [(Long , U )] = new ConcurrentLinkedQueue [(Long , U )]()
35+ case class SubscriptionHelper (observer : Observer [_ >: (L , R )]) {
36+ private val leftQueue : ConcurrentLinkedQueue [(Long , L )] = new ConcurrentLinkedQueue [(Long , L )]()
37+ private val rightQueue : ConcurrentLinkedQueue [(Long , R )] = new ConcurrentLinkedQueue [(Long , R )]()
3738
39+ private val leftCounter : AtomicLong = new AtomicLong ()
40+ private val rightCounter : AtomicLong = new AtomicLong ()
41+ @ volatile private var completedLeft : Boolean = false
42+ @ volatile private var completedRight : Boolean = false
3843 @ volatile private var terminated : Boolean = false
39- @ volatile private var observable1Subscription : Option [Subscription ] = None
40- @ volatile private var observable2Subscription : Option [Subscription ] = None
44+ @ volatile private var leftSubscription : Option [Subscription ] = None
45+ @ volatile private var rightSubscription : Option [Subscription ] = None
4146
42- def createFirstObserver : Observer [T ] = createSubObserver[T ](thisQueue, observer, firstSub = true )
43-
44- def createSecondObserver : Observer [U ] = createSubObserver[U ](thatQueue, observer, firstSub = false )
47+ def createLeftObserver : Observer [L ] = createSubObserver[L ](leftQueue, observer, isLeftSub = true )
48+ def createRightObserver : Observer [R ] = createSubObserver[R ](rightQueue, observer, isLeftSub = false )
4549
4650 private def createSubObserver [A ](
4751 queue : ConcurrentLinkedQueue [(Long , A )],
48- observer : Observer [_ >: (T , U )],
49- firstSub : Boolean
52+ observer : Observer [_ >: (L , R )],
53+ isLeftSub : Boolean
5054 ): Observer [A ] = {
5155 new Observer [A ] {
5256 @ volatile private var counter : Long = 0
@@ -56,38 +60,61 @@ private[scala] case class ZipObservable[T, U](
5660 }
5761
5862 override def onSubscribe (subscription : Subscription ): Unit = {
59- if (firstSub ) {
60- observable1Subscription = Some (subscription)
63+ if (isLeftSub ) {
64+ leftSubscription = Some (subscription)
6165 } else {
62- observable2Subscription = Some (subscription)
66+ rightSubscription = Some (subscription)
6367 }
6468
65- if (observable1Subscription .nonEmpty && observable2Subscription .nonEmpty) {
69+ if (leftSubscription .nonEmpty && rightSubscription .nonEmpty) {
6670 observer.onSubscribe(jointSubscription)
6771 }
6872 }
6973
7074 override def onComplete (): Unit = {
71- if (! firstSub) {
72- terminated = true
73- observer.onComplete()
74- }
75+ markCompleted(isLeftSub)
76+ processNext(observer)
7577 }
7678
7779 override def onNext (tResult : A ): Unit = {
80+ if (isLeftSub) leftCounter.incrementAndGet() else rightCounter.incrementAndGet()
7881 counter += 1
7982 queue.add((counter, tResult))
80- if ( ! firstSub) processNext(observer)
83+ processNext(observer)
8184 }
8285 }
8386 }
8487
85- private def processNext (observer : Observer [_ >: (T , U )]): Unit = {
86- (thisQueue.peek, thatQueue.peek) match {
87- case ((k1 : Long , _), (k2 : Long , _)) if k1 == k2 => observer.onNext((thisQueue.poll()._2, thatQueue.poll()._2))
88+ private def markCompleted (isLeftSub : Boolean ): Unit = synchronized {
89+ if (isLeftSub) {
90+ completedLeft = true
91+ } else {
92+ completedRight = true
93+ }
94+ }
95+
96+ private def completed (): Unit = synchronized {
97+ if (! terminated) {
98+ terminated = true
99+ leftSubscription.foreach(_.unsubscribe())
100+ rightSubscription.foreach(_.unsubscribe())
101+ observer.onComplete()
102+ }
103+ }
104+
105+ private def processNext (observer : Observer [_ >: (L , R )]): Unit = synchronized {
106+ (leftQueue.peek, rightQueue.peek) match {
107+ case ((k1 : Long , _), (k2 : Long , _)) if k1 == k2 =>
108+ observer.onNext((leftQueue.poll()._2, rightQueue.poll()._2))
109+ processNext(observer)
88110 case _ =>
89- if (! terminated && ! jointSubscription.isUnsubscribed) jointSubscription.request(1 ) // Uneven queues request more data
90- // from downstream so to honor the original request for data.
111+ if (! terminated && ! jointSubscription.isUnsubscribed) {
112+ if (completedLeft && rightCounter.get() >= leftCounter.get()) {
113+ completed()
114+ } else if (completedRight && leftCounter.get() >= rightCounter.get()) {
115+ completed()
116+ }
117+ }
91118 }
92119 }
93120
@@ -96,14 +123,14 @@ private[scala] case class ZipObservable[T, U](
96123 override def isUnsubscribed : Boolean = ! subscribed
97124
98125 override def request (n : Long ): Unit = {
99- observable1Subscription .foreach(_.request(n))
100- observable2Subscription .foreach(_.request(n))
126+ leftSubscription .foreach(_.request(n))
127+ rightSubscription .foreach(_.request(n))
101128 }
102129
103130 override def unsubscribe (): Unit = {
104131 subscribed = false
105- observable1Subscription .foreach(_.unsubscribe())
106- observable2Subscription .foreach(_.unsubscribe())
132+ leftSubscription .foreach(_.unsubscribe())
133+ rightSubscription .foreach(_.unsubscribe())
107134 }
108135 }
109136 }
0 commit comments