Skip to content

Commit e25fafe

Browse files
author
Igor Macedo Quintanilha
committed
fix: context propagation fifo
1 parent 926f68c commit e25fafe

File tree

3 files changed

+278
-9
lines changed

3 files changed

+278
-9
lines changed

spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,16 @@ public <T> CompletableFuture<SendResult<T>> sendAsync(@Nullable String endpointN
301301
public <T> CompletableFuture<SendResult<T>> sendAsync(@Nullable String endpointName, Message<T> message) {
302302
String endpointToUse = getEndpointName(endpointName);
303303
logger.trace("Sending message {} to endpoint {}", MessageHeaderUtils.getId(message), endpointName);
304-
return preProcessMessageForSendAsync(endpointToUse, message)
305-
.thenCompose(messageToUse -> observeAndSendAsync(messageToUse, endpointToUse)
304+
305+
// Create observation and add trace headers BEFORE async preprocessing
306+
// This ensures trace context is captured on the calling thread
307+
var context = this.observationSpecifics.createContext(message, endpointToUse);
308+
var observation = startObservation(context);
309+
var carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context.");
310+
var messageWithObservationHeaders = MessageHeaderUtils.addHeadersIfAbsent(message, carrier);
311+
312+
return preProcessMessageForSendAsync(endpointToUse, messageWithObservationHeaders).thenCompose(
313+
messageToUse -> doSendAndCompleteObservation(messageToUse, endpointToUse, context, observation)
306314
.exceptionallyCompose(
307315
t -> CompletableFuture.failedFuture(new MessagingOperationFailedException(
308316
"Message send operation failed for message %s to endpoint %s"
@@ -311,13 +319,9 @@ public <T> CompletableFuture<SendResult<T>> sendAsync(@Nullable String endpointN
311319
.whenComplete((v, t) -> logSendMessageResult(endpointToUse, message, t)));
312320
}
313321

314-
private <T> CompletableFuture<SendResult<T>> observeAndSendAsync(Message<T> message, String endpointToUse) {
315-
AbstractTemplateObservation.Context context = this.observationSpecifics.createContext(message, endpointToUse);
316-
Observation observation = startObservation(context);
317-
Map<String, Object> carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context.");
318-
Message<T> messageWithObservationHeader = MessageHeaderUtils.addHeadersIfAbsent(message, carrier);
319-
return doSendAsync(endpointToUse, convertMessageToSend(messageWithObservationHeader),
320-
messageWithObservationHeader)
322+
private <T> CompletableFuture<SendResult<T>> doSendAndCompleteObservation(Message<T> message, String endpointToUse,
323+
AbstractTemplateObservation.Context context, Observation observation) {
324+
return doSendAsync(endpointToUse, convertMessageToSend(message), message)
321325
.whenComplete((sendResult, t) -> completeObservation(sendResult, context, t, observation));
322326
}
323327

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
return preProcessMessageForSendAsync(endpointToUse, message)
2+
.thenCompose(messageToUse -> observeAndSendAsync(messageToUse, endpointToUse)
3+
.exceptionallyCompose(
4+
t -> CompletableFuture.failedFuture(new MessagingOperationFailedException(
5+
"Message send operation failed for message %s to endpoint %s"
6+
.formatted(MessageHeaderUtils.getId(message), endpointToUse),
7+
endpointToUse, message, t)))
8+
.whenComplete((v, t) -> logSendMessageResult(endpointToUse, message, t)));
9+
10+
11+
12+
13+
/**
14+
* @deprecated Use {@link #doSendAndCompleteObservation} instead. This method creates a new observation which may
15+
* capture trace context on the wrong thread if called after async operations.
16+
*/
17+
@Deprecated
18+
private <T> CompletableFuture<SendResult<T>> observeAndSendAsync(Message<T> message, String endpointToUse) {
19+
AbstractTemplateObservation.Context context = this.observationSpecifics.createContext(message, endpointToUse);
20+
Observation observation = startObservation(context);
21+
Map<String, Object> carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context.");
22+
Message<T> messageWithObservationHeader = MessageHeaderUtils.addHeadersIfAbsent(message, carrier);
23+
return doSendAsync(endpointToUse, convertMessageToSend(messageWithObservationHeader),
24+
messageWithObservationHeader)
25+
.whenComplete((sendResult, t) -> completeObservation(sendResult, context, t, observation));
26+
}
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Copyright 2013-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.awspring.cloud.sqs.integration;
17+
18+
import io.awspring.cloud.sqs.operations.SqsTemplate;
19+
import io.micrometer.observation.Observation;
20+
import io.micrometer.observation.ObservationRegistry;
21+
import io.micrometer.observation.tck.TestObservationRegistry;
22+
import io.micrometer.tracing.CurrentTraceContext;
23+
import io.micrometer.tracing.Span;
24+
import io.micrometer.tracing.TraceContext;
25+
import io.micrometer.tracing.Tracer;
26+
import io.micrometer.tracing.handler.DefaultTracingObservationHandler;
27+
import io.micrometer.tracing.handler.PropagatingReceiverTracingObservationHandler;
28+
import io.micrometer.tracing.handler.PropagatingSenderTracingObservationHandler;
29+
import io.micrometer.tracing.propagation.Propagator;
30+
import io.micrometer.tracing.test.simple.SimpleTraceContext;
31+
import io.micrometer.tracing.test.simple.SimpleTracer;
32+
import org.junit.jupiter.api.AfterEach;
33+
import org.junit.jupiter.api.BeforeAll;
34+
import org.junit.jupiter.api.Test;
35+
import org.slf4j.Logger;
36+
import org.slf4j.LoggerFactory;
37+
import org.springframework.beans.factory.annotation.Autowired;
38+
import org.springframework.boot.test.context.SpringBootTest;
39+
import org.springframework.context.annotation.Bean;
40+
import org.springframework.context.annotation.Configuration;
41+
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
42+
import software.amazon.awssdk.services.sqs.model.QueueAttributeName;
43+
44+
import java.time.Duration;
45+
import java.util.List;
46+
import java.util.Map;
47+
import java.util.UUID;
48+
49+
import static org.assertj.core.api.Assertions.assertThat;
50+
51+
/**
52+
* Integration tests for trace context propagation in FIFO queues with SqsTemplate.
53+
* <p>
54+
* Verifies that trace headers (traceparent) are correctly propagated from sender to receiver when using
55+
* {@code sendAsync()} with FIFO queues, including scenarios where queue attributes must be resolved asynchronously on
56+
* the first call and when they are cached on subsequent calls.
57+
*
58+
* @author Igor Quintanilha
59+
*/
60+
@SpringBootTest
61+
public class SqsTemplateFifoTracingIntegrationTest extends BaseSqsIntegrationTest {
62+
private static final Logger logger = LoggerFactory.getLogger(SqsTemplateFifoTracingIntegrationTest.class);
63+
64+
private static final String FIFO_QUEUE_NAME = "trace-context-test-queue.fifo";
65+
private static final String FIFO_CACHE_HIT_QUEUE_NAME = "trace-context-test-queue-cache-hit.fifo";
66+
67+
@Autowired
68+
private SqsTemplate sqsTemplate;
69+
70+
@Autowired
71+
private TestObservationRegistry observationRegistry;
72+
73+
@Autowired
74+
private CurrentTraceContext currentTraceContext;
75+
76+
@BeforeAll
77+
static void beforeTests() {
78+
var client = createAsyncClient();
79+
createFifoQueue(client, FIFO_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "false")).join();
80+
createFifoQueue(client, FIFO_CACHE_HIT_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true")).join();
81+
82+
}
83+
84+
@AfterEach
85+
void cleanupAfterEach() {
86+
observationRegistry.clear();
87+
}
88+
89+
@Test
90+
void sendAsync_toFifoQueue_shouldPropagateObservationScopeOnFirstCall() {
91+
var parentObservation = Observation.start("parent-observation", observationRegistry);
92+
var payload = new TestEvent(UUID.randomUUID().toString());
93+
String expectedTraceId;
94+
95+
try (var ignored = parentObservation.openScope()) {
96+
expectedTraceId = currentTraceContext.context().traceId();
97+
sqsTemplate.sendAsync(FIFO_QUEUE_NAME, payload).join();
98+
}
99+
finally {
100+
parentObservation.stop();
101+
}
102+
103+
logger.info("expectedTraceId={}", expectedTraceId);
104+
105+
var receivedMessage = sqsTemplate
106+
.receive(from -> from.queue(FIFO_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class)
107+
.orElseThrow(() -> new AssertionError("Expected message was not received"));
108+
109+
assertThat(receivedMessage.getPayload()).isEqualTo(payload);
110+
var traceparent = (String) receivedMessage.getHeaders().get("traceparent");
111+
assertThat(traceparent).as("traceparent header should be present").isNotNull();
112+
assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId);
113+
}
114+
115+
@Test
116+
void sendAsync_toFifoQueue_shouldCreateObservationOnCallingThreadAfterCacheHit() {
117+
// Given - Warm up: send a message to populate the queue attribute cache
118+
var warmupPayload = new TestEvent(UUID.randomUUID().toString());
119+
sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, warmupPayload).join();
120+
121+
// Drain the warmup message
122+
sqsTemplate.receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class);
123+
124+
// Given - Start a NEW observation for the actual test
125+
var observation = Observation.start("test-send-second", observationRegistry);
126+
String expectedTraceId;
127+
128+
var payload = new TestEvent(UUID.randomUUID().toString());
129+
try (var ignored = observation.openScope()) {
130+
expectedTraceId = currentTraceContext.context().traceId();
131+
// When - Second call (cache hit - queue attributes already resolved)
132+
sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, payload).join();
133+
}
134+
finally {
135+
observation.stop();
136+
}
137+
138+
logger.info("expectedTraceId={}", expectedTraceId);
139+
140+
var receivedMessage = sqsTemplate
141+
.receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class)
142+
.orElseThrow(() -> new AssertionError("Expected message was not received"));
143+
144+
assertThat(receivedMessage.getPayload()).isEqualTo(payload);
145+
var traceparent = (String) receivedMessage.getHeaders().get("traceparent");
146+
assertThat(traceparent).as("traceparent header should be present").isNotNull();
147+
assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId);
148+
}
149+
150+
@Configuration
151+
static class TestConfiguration {
152+
153+
@Bean
154+
public SqsAsyncClient sqsAsyncClient() {
155+
return createAsyncClient();
156+
}
157+
158+
@Bean
159+
public Tracer tracer() {
160+
return new SimpleTracer();
161+
}
162+
163+
@Bean
164+
public CurrentTraceContext currentTraceContext(Tracer tracer) {
165+
return ((SimpleTracer) tracer).currentTraceContext();
166+
}
167+
168+
@Bean
169+
public Propagator propagator(Tracer tracer) {
170+
return new SimplePropagator(tracer);
171+
}
172+
173+
@Bean
174+
public ObservationRegistry observationRegistry(Tracer tracer, Propagator propagator) {
175+
TestObservationRegistry registry = TestObservationRegistry.create();
176+
registry.observationConfig().observationHandler(new DefaultTracingObservationHandler(tracer));
177+
registry.observationConfig()
178+
.observationHandler(new PropagatingSenderTracingObservationHandler<>(tracer, propagator));
179+
registry.observationConfig()
180+
.observationHandler(new PropagatingReceiverTracingObservationHandler<>(tracer, propagator));
181+
return registry;
182+
}
183+
184+
@Bean
185+
public SqsTemplate sqsTemplate(SqsAsyncClient sqsAsyncClient, ObservationRegistry observationRegistry) {
186+
return SqsTemplate.builder().sqsAsyncClient(sqsAsyncClient)
187+
.configure(options -> options.observationRegistry(observationRegistry)).build();
188+
}
189+
}
190+
191+
/**
192+
* Simple W3C Trace Context propagator for testing. In production, you would use a library like
193+
* micrometer-tracing-bridge-brave or micrometer-tracing-bridge-otel which provide full-featured propagators.
194+
*/
195+
static class SimplePropagator implements Propagator {
196+
197+
private final Tracer tracer;
198+
199+
SimplePropagator(Tracer tracer) {
200+
this.tracer = tracer;
201+
}
202+
203+
@Override
204+
public List<String> fields() {
205+
return List.of("traceparent", "tracestate");
206+
}
207+
208+
@Override
209+
public <C> void inject(TraceContext context, C carrier, Setter<C> setter) {
210+
// W3C Trace Context format: version-traceId-spanId-flags
211+
var traceparent = String.format("00-%s-%s-01", context.traceId(), context.spanId());
212+
setter.set(carrier, "traceparent", traceparent);
213+
}
214+
215+
@Override
216+
public <C> Span.Builder extract(C carrier, Getter<C> getter) {
217+
var traceparent = getter.get(carrier, "traceparent");
218+
if (traceparent == null || traceparent.isEmpty()) {
219+
return tracer.spanBuilder().setNoParent();
220+
}
221+
// Parse W3C format: 00-traceId-spanId-01
222+
String[] parts = traceparent.split("-");
223+
if (parts.length < 4) {
224+
return tracer.spanBuilder().setNoParent();
225+
}
226+
// Use tracer to create span builder with extracted context
227+
Span.Builder builder = tracer.spanBuilder();
228+
var traceContext = new SimpleTraceContext();
229+
traceContext.setTraceId(parts[1]);
230+
traceContext.setParentId(parts[2]);
231+
traceContext.setSpanId(parts[3]);
232+
builder.setParent(traceContext);
233+
return builder;
234+
}
235+
}
236+
237+
record TestEvent(String data) {
238+
}
239+
}

0 commit comments

Comments
 (0)