From f7931251502d212815e3eb94581ccdfc1d33ad2e Mon Sep 17 00:00:00 2001 From: Igor Macedo Quintanilha Date: Mon, 24 Nov 2025 10:58:25 +0000 Subject: [PATCH] fix: context propagation fifo --- .../operations/AbstractMessagingTemplate.java | 22 +- ...SqsTemplateFifoTracingIntegrationTest.java | 239 ++++++++++++++++++ 2 files changed, 252 insertions(+), 9 deletions(-) create mode 100644 spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java index bb6ed9e50..e2a2986b9 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java @@ -301,8 +301,16 @@ public CompletableFuture> sendAsync(@Nullable String endpointN public CompletableFuture> sendAsync(@Nullable String endpointName, Message message) { String endpointToUse = getEndpointName(endpointName); logger.trace("Sending message {} to endpoint {}", MessageHeaderUtils.getId(message), endpointName); - return preProcessMessageForSendAsync(endpointToUse, message) - .thenCompose(messageToUse -> observeAndSendAsync(messageToUse, endpointToUse) + + // Create observation and add trace headers BEFORE async preprocessing + // This ensures trace context is captured on the calling thread + var context = this.observationSpecifics.createContext(message, endpointToUse); + var observation = startObservation(context); + var carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context."); + var messageWithObservationHeaders = MessageHeaderUtils.addHeadersIfAbsent(message, carrier); + + return preProcessMessageForSendAsync(endpointToUse, messageWithObservationHeaders).thenCompose( + messageToUse -> doSendAndCompleteObservation(messageToUse, endpointToUse, context, observation) .exceptionallyCompose( t -> CompletableFuture.failedFuture(new MessagingOperationFailedException( "Message send operation failed for message %s to endpoint %s" @@ -311,13 +319,9 @@ public CompletableFuture> sendAsync(@Nullable String endpointN .whenComplete((v, t) -> logSendMessageResult(endpointToUse, message, t))); } - private CompletableFuture> observeAndSendAsync(Message message, String endpointToUse) { - AbstractTemplateObservation.Context context = this.observationSpecifics.createContext(message, endpointToUse); - Observation observation = startObservation(context); - Map carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context."); - Message messageWithObservationHeader = MessageHeaderUtils.addHeadersIfAbsent(message, carrier); - return doSendAsync(endpointToUse, convertMessageToSend(messageWithObservationHeader), - messageWithObservationHeader) + private CompletableFuture> doSendAndCompleteObservation(Message message, String endpointToUse, + AbstractTemplateObservation.Context context, Observation observation) { + return doSendAsync(endpointToUse, convertMessageToSend(message), message) .whenComplete((sendResult, t) -> completeObservation(sendResult, context, t, observation)); } diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java new file mode 100644 index 000000000..f03b8f3ef --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java @@ -0,0 +1,239 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.integration; + +import io.awspring.cloud.sqs.operations.SqsTemplate; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.tracing.CurrentTraceContext; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.DefaultTracingObservationHandler; +import io.micrometer.tracing.handler.PropagatingReceiverTracingObservationHandler; +import io.micrometer.tracing.handler.PropagatingSenderTracingObservationHandler; +import io.micrometer.tracing.propagation.Propagator; +import io.micrometer.tracing.test.simple.SimpleTraceContext; +import io.micrometer.tracing.test.simple.SimpleTracer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.model.QueueAttributeName; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for trace context propagation in FIFO queues with SqsTemplate. + *

+ * Verifies that trace headers (traceparent) are correctly propagated from sender to receiver when using + * {@code sendAsync()} with FIFO queues, including scenarios where queue attributes must be resolved asynchronously on + * the first call and when they are cached on subsequent calls. + * + * @author Igor Quintanilha + */ +@SpringBootTest +public class SqsTemplateFifoTracingIntegrationTest extends BaseSqsIntegrationTest { + private static final Logger logger = LoggerFactory.getLogger(SqsTemplateFifoTracingIntegrationTest.class); + + private static final String FIFO_QUEUE_NAME = "trace-context-test-queue.fifo"; + private static final String FIFO_CACHE_HIT_QUEUE_NAME = "trace-context-test-queue-cache-hit.fifo"; + + @Autowired + private SqsTemplate sqsTemplate; + + @Autowired + private TestObservationRegistry observationRegistry; + + @Autowired + private CurrentTraceContext currentTraceContext; + + @BeforeAll + static void beforeTests() { + var client = createAsyncClient(); + createFifoQueue(client, FIFO_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "false")).join(); + createFifoQueue(client, FIFO_CACHE_HIT_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true")).join(); + + } + + @AfterEach + void cleanupAfterEach() { + observationRegistry.clear(); + } + + @Test + void sendAsync_toFifoQueue_shouldPropagateObservationScopeOnFirstCall() { + var parentObservation = Observation.start("parent-observation", observationRegistry); + var payload = new TestEvent(UUID.randomUUID().toString()); + String expectedTraceId; + + try (var ignored = parentObservation.openScope()) { + expectedTraceId = currentTraceContext.context().traceId(); + sqsTemplate.sendAsync(FIFO_QUEUE_NAME, payload).join(); + } + finally { + parentObservation.stop(); + } + + logger.info("expectedTraceId={}", expectedTraceId); + + var receivedMessage = sqsTemplate + .receive(from -> from.queue(FIFO_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class) + .orElseThrow(() -> new AssertionError("Expected message was not received")); + + assertThat(receivedMessage.getPayload()).isEqualTo(payload); + var traceparent = (String) receivedMessage.getHeaders().get("traceparent"); + assertThat(traceparent).as("traceparent header should be present").isNotNull(); + assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId); + } + + @Test + void sendAsync_toFifoQueue_shouldCreateObservationOnCallingThreadAfterCacheHit() { + // Given - Warm up: send a message to populate the queue attribute cache + var warmupPayload = new TestEvent(UUID.randomUUID().toString()); + sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, warmupPayload).join(); + + // Drain the warmup message + sqsTemplate.receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class); + + // Given - Start a NEW observation for the actual test + var observation = Observation.start("test-send-second", observationRegistry); + String expectedTraceId; + + var payload = new TestEvent(UUID.randomUUID().toString()); + try (var ignored = observation.openScope()) { + expectedTraceId = currentTraceContext.context().traceId(); + // When - Second call (cache hit - queue attributes already resolved) + sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, payload).join(); + } + finally { + observation.stop(); + } + + logger.info("expectedTraceId={}", expectedTraceId); + + var receivedMessage = sqsTemplate + .receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class) + .orElseThrow(() -> new AssertionError("Expected message was not received")); + + assertThat(receivedMessage.getPayload()).isEqualTo(payload); + var traceparent = (String) receivedMessage.getHeaders().get("traceparent"); + assertThat(traceparent).as("traceparent header should be present").isNotNull(); + assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId); + } + + @Configuration + static class TestConfiguration { + + @Bean + public SqsAsyncClient sqsAsyncClient() { + return createAsyncClient(); + } + + @Bean + public Tracer tracer() { + return new SimpleTracer(); + } + + @Bean + public CurrentTraceContext currentTraceContext(Tracer tracer) { + return ((SimpleTracer) tracer).currentTraceContext(); + } + + @Bean + public Propagator propagator(Tracer tracer) { + return new SimplePropagator(tracer); + } + + @Bean + public ObservationRegistry observationRegistry(Tracer tracer, Propagator propagator) { + TestObservationRegistry registry = TestObservationRegistry.create(); + registry.observationConfig().observationHandler(new DefaultTracingObservationHandler(tracer)); + registry.observationConfig() + .observationHandler(new PropagatingSenderTracingObservationHandler<>(tracer, propagator)); + registry.observationConfig() + .observationHandler(new PropagatingReceiverTracingObservationHandler<>(tracer, propagator)); + return registry; + } + + @Bean + public SqsTemplate sqsTemplate(SqsAsyncClient sqsAsyncClient, ObservationRegistry observationRegistry) { + return SqsTemplate.builder().sqsAsyncClient(sqsAsyncClient) + .configure(options -> options.observationRegistry(observationRegistry)).build(); + } + } + + /** + * Simple W3C Trace Context propagator for testing. In production, you would use a library like + * micrometer-tracing-bridge-brave or micrometer-tracing-bridge-otel which provide full-featured propagators. + */ + static class SimplePropagator implements Propagator { + + private final Tracer tracer; + + SimplePropagator(Tracer tracer) { + this.tracer = tracer; + } + + @Override + public List fields() { + return List.of("traceparent", "tracestate"); + } + + @Override + public void inject(TraceContext context, C carrier, Setter setter) { + // W3C Trace Context format: version-traceId-spanId-flags + var traceparent = String.format("00-%s-%s-01", context.traceId(), context.spanId()); + setter.set(carrier, "traceparent", traceparent); + } + + @Override + public Span.Builder extract(C carrier, Getter getter) { + var traceparent = getter.get(carrier, "traceparent"); + if (traceparent == null || traceparent.isEmpty()) { + return tracer.spanBuilder().setNoParent(); + } + // Parse W3C format: 00-traceId-spanId-01 + String[] parts = traceparent.split("-"); + if (parts.length < 4) { + return tracer.spanBuilder().setNoParent(); + } + // Use tracer to create span builder with extracted context + Span.Builder builder = tracer.spanBuilder(); + var traceContext = new SimpleTraceContext(); + traceContext.setTraceId(parts[1]); + traceContext.setParentId(parts[2]); + traceContext.setSpanId(parts[3]); + builder.setParent(traceContext); + return builder; + } + } + + record TestEvent(String data) { + } +}