@@ -155,11 +155,13 @@ CompletableFuture<V> load(K key, Object loadContext) {
155155 }
156156 }
157157
158+ @ SuppressWarnings ("unchecked" )
158159 Object getCacheKey (K key ) {
159160 return loaderOptions .cacheKeyFunction ().isPresent () ?
160161 loaderOptions .cacheKeyFunction ().get ().getKey (key ) : key ;
161162 }
162163
164+ @ SuppressWarnings ("unchecked" )
163165 Object getCacheKeyWithContext (K key , Object context ) {
164166 return loaderOptions .cacheKeyFunction ().isPresent () ?
165167 loaderOptions .cacheKeyFunction ().get ().getKeyWithContext (key , context ) : key ;
@@ -511,6 +513,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec
511513
512514 BatchLoaderScheduler batchLoaderScheduler = loaderOptions .getBatchLoaderScheduler ();
513515 if (batchLoadFunction instanceof BatchPublisherWithContext ) {
516+ //noinspection unchecked
514517 BatchPublisherWithContext <K , V > loadFunction = (BatchPublisherWithContext <K , V >) batchLoadFunction ;
515518 if (batchLoaderScheduler != null ) {
516519 BatchLoaderScheduler .ScheduledBatchPublisherCall loadCall = () -> loadFunction .load (keys , subscriber , environment );
@@ -519,6 +522,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec
519522 loadFunction .load (keys , subscriber , environment );
520523 }
521524 } else {
525+ //noinspection unchecked
522526 BatchPublisher <K , V > loadFunction = (BatchPublisher <K , V >) batchLoadFunction ;
523527 if (batchLoaderScheduler != null ) {
524528 BatchLoaderScheduler .ScheduledBatchPublisherCall loadCall = () -> loadFunction .load (keys , subscriber );
@@ -536,6 +540,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List
536540
537541 BatchLoaderScheduler batchLoaderScheduler = loaderOptions .getBatchLoaderScheduler ();
538542 if (batchLoadFunction instanceof MappedBatchPublisherWithContext ) {
543+ //noinspection unchecked
539544 MappedBatchPublisherWithContext <K , V > loadFunction = (MappedBatchPublisherWithContext <K , V >) batchLoadFunction ;
540545 if (batchLoaderScheduler != null ) {
541546 BatchLoaderScheduler .ScheduledBatchPublisherCall loadCall = () -> loadFunction .load (keys , subscriber , environment );
@@ -544,6 +549,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List
544549 loadFunction .load (keys , subscriber , environment );
545550 }
546551 } else {
552+ //noinspection unchecked
547553 MappedBatchPublisher <K , V > loadFunction = (MappedBatchPublisher <K , V >) batchLoadFunction ;
548554 if (batchLoaderScheduler != null ) {
549555 BatchLoaderScheduler .ScheduledBatchPublisherCall loadCall = () -> loadFunction .load (keys , subscriber );
@@ -618,24 +624,23 @@ private static <T> DispatchResult<T> emptyDispatchResult() {
618624 return (DispatchResult <T >) EMPTY_DISPATCH_RESULT ;
619625 }
620626
621- private class DataLoaderSubscriber implements Subscriber <V > {
627+ private abstract class DataLoaderSubscriberBase < T > implements Subscriber <T > {
622628
623- private final CompletableFuture <List <V >> valuesFuture ;
624- private final List <K > keys ;
625- private final List <Object > callContexts ;
626- private final List <CompletableFuture <V >> queuedFutures ;
629+ final CompletableFuture <List <V >> valuesFuture ;
630+ final List <K > keys ;
631+ final List <Object > callContexts ;
632+ final List <CompletableFuture <V >> queuedFutures ;
627633
628- private final List <K > clearCacheKeys = new ArrayList <>();
629- private final List <V > completedValues = new ArrayList <>();
630- private int idx = 0 ;
631- private boolean onErrorCalled = false ;
632- private boolean onCompleteCalled = false ;
634+ List <K > clearCacheKeys = new ArrayList <>();
635+ List <V > completedValues = new ArrayList <>();
636+ boolean onErrorCalled = false ;
637+ boolean onCompleteCalled = false ;
633638
634- private DataLoaderSubscriber (
635- CompletableFuture <List <V >> valuesFuture ,
636- List <K > keys ,
637- List <Object > callContexts ,
638- List <CompletableFuture <V >> queuedFutures
639+ DataLoaderSubscriberBase (
640+ CompletableFuture <List <V >> valuesFuture ,
641+ List <K > keys ,
642+ List <Object > callContexts ,
643+ List <CompletableFuture <V >> queuedFutures
639644 ) {
640645 this .valuesFuture = valuesFuture ;
641646 this .keys = keys ;
@@ -648,55 +653,97 @@ public void onSubscribe(Subscription subscription) {
648653 subscription .request (keys .size ());
649654 }
650655
651- // onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
652- // correctness (at the cost of speed).
653656 @ Override
654- public synchronized void onNext (V value ) {
657+ public void onNext (T v ) {
655658 assertState (!onErrorCalled , () -> "onError has already been called; onNext may not be invoked." );
656659 assertState (!onCompleteCalled , () -> "onComplete has already been called; onNext may not be invoked." );
660+ }
657661
658- K key = keys .get (idx );
659- Object callContext = callContexts .get (idx );
660- CompletableFuture <V > future = queuedFutures .get (idx );
662+ @ Override
663+ public void onComplete () {
664+ assertState (!onErrorCalled , () -> "onError has already been called; onComplete may not be invoked." );
665+ onCompleteCalled = true ;
666+ }
667+
668+ @ Override
669+ public void onError (Throwable throwable ) {
670+ assertState (!onCompleteCalled , () -> "onComplete has already been called; onError may not be invoked." );
671+ onErrorCalled = true ;
672+
673+ stats .incrementBatchLoadExceptionCount (new IncrementBatchLoadExceptionCountStatisticsContext <>(keys , callContexts ));
674+ }
675+
676+ /*
677+ * A value has arrived - how do we complete the future that's associated with it in a common way
678+ */
679+ void onNextValue (K key , V value , Object callContext , List <CompletableFuture <V >> futures ) {
661680 if (value instanceof Try ) {
662681 // we allow the batch loader to return a Try so we can better represent a computation
663682 // that might have worked or not.
683+ //noinspection unchecked
664684 Try <V > tryValue = (Try <V >) value ;
665685 if (tryValue .isSuccess ()) {
666- future . complete (tryValue .get ());
686+ futures . forEach ( f -> f . complete (tryValue .get () ));
667687 } else {
668688 stats .incrementLoadErrorCount (new IncrementLoadErrorCountStatisticsContext <>(key , callContext ));
669- future . completeExceptionally (tryValue .getThrowable ());
670- clearCacheKeys .add (keys . get ( idx ) );
689+ futures . forEach ( f -> f . completeExceptionally (tryValue .getThrowable () ));
690+ clearCacheKeys .add (key );
671691 }
672692 } else {
673- future . complete (value );
693+ futures . forEach ( f -> f . complete (value ) );
674694 }
695+ }
696+
697+ Throwable unwrapThrowable (Throwable ex ) {
698+ if (ex instanceof CompletionException ) {
699+ ex = ex .getCause ();
700+ }
701+ return ex ;
702+ }
703+ }
704+
705+ private class DataLoaderSubscriber extends DataLoaderSubscriberBase <V > {
706+
707+ private int idx = 0 ;
708+
709+ private DataLoaderSubscriber (
710+ CompletableFuture <List <V >> valuesFuture ,
711+ List <K > keys ,
712+ List <Object > callContexts ,
713+ List <CompletableFuture <V >> queuedFutures
714+ ) {
715+ super (valuesFuture , keys , callContexts , queuedFutures );
716+ }
717+
718+ // onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
719+ // correctness (at the cost of speed).
720+ @ Override
721+ public synchronized void onNext (V value ) {
722+ super .onNext (value );
723+
724+ K key = keys .get (idx );
725+ Object callContext = callContexts .get (idx );
726+ CompletableFuture <V > future = queuedFutures .get (idx );
727+ onNextValue (key , value , callContext , List .of (future ));
675728
676729 completedValues .add (value );
677730 idx ++;
678731 }
679732
680- @ Override
681- public void onComplete () {
682- assertState (!onErrorCalled , () -> "onError has already been called; onComplete may not be invoked." );
683- onCompleteCalled = true ;
684733
734+ @ Override
735+ public synchronized void onComplete () {
736+ super .onComplete ();
685737 assertResultSize (keys , completedValues );
686738
687739 possiblyClearCacheEntriesOnExceptions (clearCacheKeys );
688740 valuesFuture .complete (completedValues );
689741 }
690742
691743 @ Override
692- public void onError (Throwable ex ) {
693- assertState (!onCompleteCalled , () -> "onComplete has already been called; onError may not be invoked." );
694- onErrorCalled = true ;
695-
696- stats .incrementBatchLoadExceptionCount (new IncrementBatchLoadExceptionCountStatisticsContext <>(keys , callContexts ));
697- if (ex instanceof CompletionException ) {
698- ex = ex .getCause ();
699- }
744+ public synchronized void onError (Throwable ex ) {
745+ super .onError (ex );
746+ ex = unwrapThrowable (ex );
700747 // Set the remaining keys to the exception.
701748 for (int i = idx ; i < queuedFutures .size (); i ++) {
702749 K key = keys .get (i );
@@ -705,33 +752,25 @@ public void onError(Throwable ex) {
705752 // clear any cached view of this key because they all failed
706753 dataLoader .clear (key );
707754 }
755+ valuesFuture .completeExceptionally (ex );
708756 }
757+
709758 }
710759
711- private class DataLoaderMapEntrySubscriber implements Subscriber <Map .Entry <K , V >> {
712- private final CompletableFuture <List <V >> valuesFuture ;
713- private final List <K > keys ;
714- private final List <Object > callContexts ;
715- private final List <CompletableFuture <V >> queuedFutures ;
760+ private class DataLoaderMapEntrySubscriber extends DataLoaderSubscriberBase <Map .Entry <K , V >> {
761+
716762 private final Map <K , Object > callContextByKey ;
717763 private final Map <K , List <CompletableFuture <V >>> queuedFuturesByKey ;
718-
719- private final List <K > clearCacheKeys = new ArrayList <>();
720764 private final Map <K , V > completedValuesByKey = new HashMap <>();
721- private boolean onErrorCalled = false ;
722- private boolean onCompleteCalled = false ;
765+
723766
724767 private DataLoaderMapEntrySubscriber (
725- CompletableFuture <List <V >> valuesFuture ,
726- List <K > keys ,
727- List <Object > callContexts ,
728- List <CompletableFuture <V >> queuedFutures
768+ CompletableFuture <List <V >> valuesFuture ,
769+ List <K > keys ,
770+ List <Object > callContexts ,
771+ List <CompletableFuture <V >> queuedFutures
729772 ) {
730- this .valuesFuture = valuesFuture ;
731- this .keys = keys ;
732- this .callContexts = callContexts ;
733- this .queuedFutures = queuedFutures ;
734-
773+ super (valuesFuture , keys , callContexts , queuedFutures );
735774 this .callContextByKey = new HashMap <>();
736775 this .queuedFuturesByKey = new HashMap <>();
737776 for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
@@ -743,42 +782,24 @@ private DataLoaderMapEntrySubscriber(
743782 }
744783 }
745784
746- @ Override
747- public void onSubscribe (Subscription subscription ) {
748- subscription .request (keys .size ());
749- }
750785
751786 @ Override
752- public void onNext (Map .Entry <K , V > entry ) {
753- assertState (!onErrorCalled , () -> "onError has already been called; onNext may not be invoked." );
754- assertState (!onCompleteCalled , () -> "onComplete has already been called; onNext may not be invoked." );
787+ public synchronized void onNext (Map .Entry <K , V > entry ) {
788+ super .onNext (entry );
755789 K key = entry .getKey ();
756790 V value = entry .getValue ();
757791
758792 Object callContext = callContextByKey .get (key );
759793 List <CompletableFuture <V >> futures = queuedFuturesByKey .get (key );
760- if (value instanceof Try ) {
761- // we allow the batch loader to return a Try so we can better represent a computation
762- // that might have worked or not.
763- Try <V > tryValue = (Try <V >) value ;
764- if (tryValue .isSuccess ()) {
765- futures .forEach (f -> f .complete (tryValue .get ()));
766- } else {
767- stats .incrementLoadErrorCount (new IncrementLoadErrorCountStatisticsContext <>(key , callContext ));
768- futures .forEach (f -> f .completeExceptionally (tryValue .getThrowable ()));
769- clearCacheKeys .add (key );
770- }
771- } else {
772- futures .forEach (f -> f .complete (value ));
773- }
794+
795+ onNextValue (key , value , callContext , futures );
774796
775797 completedValuesByKey .put (key , value );
776798 }
777799
778800 @ Override
779- public void onComplete () {
780- assertState (!onErrorCalled , () -> "onError has already been called; onComplete may not be invoked." );
781- onCompleteCalled = true ;
801+ public synchronized void onComplete () {
802+ super .onComplete ();
782803
783804 possiblyClearCacheEntriesOnExceptions (clearCacheKeys );
784805 List <V > values = new ArrayList <>(keys .size ());
@@ -790,14 +811,9 @@ public void onComplete() {
790811 }
791812
792813 @ Override
793- public void onError (Throwable ex ) {
794- assertState (!onCompleteCalled , () -> "onComplete has already been called; onError may not be invoked." );
795- onErrorCalled = true ;
796-
797- stats .incrementBatchLoadExceptionCount (new IncrementBatchLoadExceptionCountStatisticsContext <>(keys , callContexts ));
798- if (ex instanceof CompletionException ) {
799- ex = ex .getCause ();
800- }
814+ public synchronized void onError (Throwable ex ) {
815+ super .onError (ex );
816+ ex = unwrapThrowable (ex );
801817 // Complete the futures for the remaining keys with the exception.
802818 for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
803819 K key = keys .get (idx );
@@ -810,6 +826,7 @@ public void onError(Throwable ex) {
810826 dataLoader .clear (key );
811827 }
812828 }
829+ valuesFuture .completeExceptionally (ex );
813830 }
814831 }
815832}
0 commit comments