diff --git a/.changelog/1763408225.md b/.changelog/1763408225.md new file mode 100644 index 00000000000..649a2120f3d --- /dev/null +++ b/.changelog/1763408225.md @@ -0,0 +1,10 @@ +--- +applies_to: ["server"] +authors: ["rcoh"] +references: ["smithy-rs#4400", "smithy-rs#4397"] +breaking: true +new_feature: false +bug_fix: true +--- +Fix issue where SigV4 envelopes for EventStreams did not support the initial message. This is _technically_ a breaking change but should not break consumers in practice since the +resulting type has the same methods. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index 4b6a39fc6e4..cf6c44b5f68 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -96,12 +96,6 @@ sealed class HttpBindingSection(name: String) : Section(name) { data class AfterDeserializingIntoADateTimeOfHttpHeaders(val memberShape: MemberShape) : HttpBindingSection("AfterDeserializingIntoADateTimeOfHttpHeaders") - - data class BeforeCreatingEventStreamReceiver( - val operationShape: OperationShape, - val unionShape: UnionShape, - val unmarshallerVariableName: String, - ) : HttpBindingSection("BeforeCreatingEventStreamReceiver") } typealias HttpBindingCustomization = NamedCustomization @@ -282,21 +276,11 @@ class HttpBindingGenerator( "unmarshallerConstructorFn" to unmarshallerConstructorFn, ) - // Allow customizations to wrap the unmarshaller - for (customization in customizations) { - customization.section( - HttpBindingSection.BeforeCreatingEventStreamReceiver( - operationShape, - targetShape, - "unmarshaller", - ), - )(this) - } - rustTemplate( """ let body = std::mem::replace(body, #{SdkBody}::taken()); - Ok(#{receiver:W}) + let receiver = #{receiver:W}; + Ok(receiver) """, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "receiver" to diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index dc2b562221b..ab7ecf037d8 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -116,7 +116,7 @@ val commonCodegenTests = "../codegen-core/common-test-models".let { commonModels ) } // When iterating on protocol tests use this to speed up codegen: -// .filter { it.module == "rpcv2Cbor_extras" || it.module == "rpcv2Cbor_extras_no_initial_response" } +// .filter { it.module == "rpcv2Cbor_extras" || it.module == "rpcv2Cbor_extras_no_initial_response" } val customCodegenTests = "custom-test-models".let { customModels -> listOf( diff --git a/codegen-server-test/integration-tests/Cargo.lock b/codegen-server-test/integration-tests/Cargo.lock index b752b100e92..215926c097a 100644 --- a/codegen-server-test/integration-tests/Cargo.lock +++ b/codegen-server-test/integration-tests/Cargo.lock @@ -76,7 +76,7 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.5" +version = "0.62.6" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -120,7 +120,7 @@ dependencies = [ [[package]] name = "aws-smithy-http-server" -version = "0.65.8" +version = "0.65.9" dependencies = [ "aws-smithy-cbor", "aws-smithy-http", @@ -450,8 +450,10 @@ dependencies = [ "hyper-util", "rpcv2cbor_extras", "rpcv2cbor_extras_no_initial_response", + "rstest", "tokio", "tokio-stream", + "tracing", ] [[package]] @@ -475,6 +477,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -482,6 +499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -490,6 +508,23 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + [[package]] name = "futures-macro" version = "0.3.31" @@ -513,15 +548,25 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -545,6 +590,12 @@ version = "0.32.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.3.27" @@ -1169,6 +1220,15 @@ dependencies = [ "yansi", ] +[[package]] +name = "proc-macro-crate" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1266,6 +1326,12 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "roxmltree" version = "0.14.1" @@ -1314,12 +1380,51 @@ dependencies = [ "tracing", ] +[[package]] +name = "rstest" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1338,6 +1443,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "separator" version = "0.4.1" @@ -1617,6 +1728,36 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.23.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" +dependencies = [ + "indexmap", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e" +dependencies = [ + "winnow", +] + [[package]] name = "tower" version = "0.4.13" @@ -2054,6 +2195,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/codegen-server-test/integration-tests/eventstreams/Cargo.toml b/codegen-server-test/integration-tests/eventstreams/Cargo.toml index 811e5f6785c..6773b132bbb 100644 --- a/codegen-server-test/integration-tests/eventstreams/Cargo.toml +++ b/codegen-server-test/integration-tests/eventstreams/Cargo.toml @@ -18,3 +18,5 @@ aws-smithy-runtime = { workspace = true } http-body-util = "0.1.3" hyper-util = { version = "0.1.17", features = ["client-legacy", "tokio", "http2", "http1"] } tokio-stream = "0.1.17" +tracing = "0.1.41" +rstest = "0.23" diff --git a/codegen-server-test/integration-tests/eventstreams/src/lib.rs b/codegen-server-test/integration-tests/eventstreams/src/lib.rs index aa437f5440f..ea7c2ebc304 100644 --- a/codegen-server-test/integration-tests/eventstreams/src/lib.rs +++ b/codegen-server-test/integration-tests/eventstreams/src/lib.rs @@ -57,7 +57,7 @@ impl ManualEventStreamClient { tokio::spawn(async move { while let Some(message) = message_receiver.recv().await { let mut buffer = Vec::new(); - if let Err(_) = write_message_to(&message, &mut buffer) { + if write_message_to(&message, &mut buffer).is_err() { break; } let _ = frame_sender @@ -131,7 +131,7 @@ impl ManualEventStreamClient { self.message_sender .send(message) .await - .map_err(|e| format!("Send failed: {}", e)) + .map_err(|e| format!("Send failed: {e}")) } /// Receives the next response message. diff --git a/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs b/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs index 4f52d7141fa..e5bf0d3d4a8 100644 --- a/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs +++ b/codegen-server-test/integration-tests/eventstreams/tests/structured_eventstream_tests.rs @@ -9,30 +9,34 @@ use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; use eventstreams::{ManualEventStreamClient, RecvError}; use rpcv2cbor_extras::model::{Event, Events}; +use rpcv2cbor_extras::sigv4_event_stream::SignedEvent; use rpcv2cbor_extras::{error, input, output, RpcV2CborService, RpcV2CborServiceConfig}; use std::sync::{Arc, Mutex}; use tokio::net::TcpListener; #[derive(Debug, Default, Clone)] struct StreamingOperationState { - events: Vec, + events: Vec>, num_calls: usize, + initial_signature: Option>, } #[derive(Debug, Default, Clone)] struct StreamingOperationWithInitialDataState { initial_data: Option, - events: Vec, + events: Vec>, #[allow(dead_code)] num_calls: usize, + initial_signature: Option>, } #[derive(Debug, Default, Clone)] struct StreamingOperationWithOptionalDataState { optional_data: Option, - events: Vec, + events: Vec>, #[allow(dead_code)] num_calls: usize, + initial_signature: Option>, } #[derive(Debug, Default, Clone)] @@ -90,7 +94,7 @@ impl TestServer { Self { addr, state } } - fn streaming_operation_events(&self) -> Vec { + fn streaming_operation_events(&self) -> Vec> { self.state .lock() .unwrap() @@ -99,7 +103,7 @@ impl TestServer { .clone() } - fn streaming_operation_with_initial_data_events(&self) -> Vec { + fn streaming_operation_with_initial_data_events(&self) -> Vec> { self.state .lock() .unwrap() @@ -117,7 +121,7 @@ impl TestServer { .clone() } - fn streaming_operation_with_optional_data_events(&self) -> Vec { + fn streaming_operation_with_optional_data_events(&self) -> Vec> { self.state .lock() .unwrap() @@ -134,6 +138,24 @@ impl TestServer { .optional_data .clone() } + + fn initial_signature(&self) -> Option> { + self.state + .lock() + .unwrap() + .streaming_operation_with_initial_data + .initial_signature + .clone() + } + + fn streaming_operation_initial_signature(&self) -> Option> { + self.state + .lock() + .unwrap() + .streaming_operation + .initial_signature + .clone() + } } async fn streaming_operation_handler( @@ -141,18 +163,26 @@ async fn streaming_operation_handler( state: Arc>, ) -> Result { state.lock().unwrap().streaming_operation.num_calls += 1; - let ev = input.events.recv().await; + state.lock().unwrap().streaming_operation.initial_signature = input + .events + .initial_signature() + .map(|s| s.chunk_signature.to_vec()); - if let Ok(Some(signed_event)) = &ev { - // Extract the actual event from the SignedEvent wrapper - let actual_event = &signed_event.message; - state - .lock() - .unwrap() - .streaming_operation - .events - .push(actual_event.clone()); - } + let state_clone = state.clone(); + tokio::spawn(async move { + while let Ok(Some(signed_event)) = input.events.recv().await { + tracing::debug!( + "streaming_operation received event: {:?}", + signed_event.message + ); + state_clone + .lock() + .unwrap() + .streaming_operation + .events + .push(signed_event); + } + }); Ok(output::StreamingOperationOutput::builder() .events(EventStreamSender::once(Ok(Events::A(Event {})))) @@ -173,19 +203,30 @@ async fn streaming_operation_with_initial_data_handler( .unwrap() .streaming_operation_with_initial_data .initial_data = Some(input.initial_data); + state + .lock() + .unwrap() + .streaming_operation_with_initial_data + .initial_signature = input + .events + .initial_signature() + .map(|s| s.chunk_signature.to_vec()); - let ev = input.events.recv().await; - - if let Ok(Some(signed_event)) = &ev { - // Extract the actual event from the SignedEvent wrapper - let actual_event = &signed_event.message; - state - .lock() - .unwrap() - .streaming_operation_with_initial_data - .events - .push(actual_event.clone()); - } + let state_clone = state.clone(); + tokio::spawn(async move { + while let Ok(Some(signed_event)) = input.events.recv().await { + tracing::debug!( + "streaming_operation_with_initial_data received event: {:?}", + signed_event.message + ); + state_clone + .lock() + .unwrap() + .streaming_operation_with_initial_data + .events + .push(signed_event); + } + }); Ok(output::StreamingOperationWithInitialDataOutput::builder() .events(EventStreamSender::once(Ok(Events::A(Event {})))) @@ -200,7 +241,14 @@ async fn streaming_operation_with_initial_response_handler( output::StreamingOperationWithInitialResponseOutput, error::StreamingOperationWithInitialResponseError, > { - let _ev = input.events.recv().await; + tokio::spawn(async move { + while let Ok(Some(event)) = input.events.recv().await { + tracing::debug!( + "streaming_operation_with_initial_response received event: {:?}", + event + ); + } + }); Ok( output::StreamingOperationWithInitialResponseOutput::builder() @@ -224,17 +272,30 @@ async fn streaming_operation_with_optional_data_handler( .unwrap() .streaming_operation_with_optional_data .optional_data = input.optional_data; + state + .lock() + .unwrap() + .streaming_operation_with_optional_data + .initial_signature = input + .events + .initial_signature() + .map(|s| s.chunk_signature.to_vec()); - let ev = input.events.recv().await; - - if let Ok(Some(event)) = &ev { - state - .lock() - .unwrap() - .streaming_operation_with_optional_data - .events - .push(event.message.clone()); - } + let state_clone = state.clone(); + tokio::spawn(async move { + while let Ok(Some(event)) = input.events.recv().await { + tracing::debug!( + "streaming_operation_with_optional_data received event: {:?}", + event + ); + state_clone + .lock() + .unwrap() + .streaming_operation_with_optional_data + .events + .push(event); + } + }); Ok(output::StreamingOperationWithOptionalDataOutput::builder() .optional_response_data(Some("optional response".to_string())) @@ -255,7 +316,7 @@ struct TestHarness { impl TestHarness { async fn new(operation: &str) -> Self { let server = TestServer::start().await; - let path = format!("/service/RpcV2CborService/operation/{}", operation); + let path = format!("/service/RpcV2CborService/operation/{operation}"); let client = ManualEventStreamClient::connect_to_service( server.addr, &path, @@ -271,11 +332,6 @@ impl TestHarness { } } - async fn send_initial_request(&mut self) { - let msg = build_initial_request(); - self.client.send(msg).await.ok(); - } - async fn send_initial_data(&mut self, data: &str) { let msg = build_initial_data_message(data); self.client.send(msg).await.ok(); @@ -352,39 +408,41 @@ fn build_event(event_type: &str) -> Message { Message::new_from_parts(headers, empty_cbor) } -fn build_sigv4_signed_event(event_type: &str) -> Message { +fn sign_message(inner_message: Message, signature: &[u8], timestamp_secs: i64) -> Message { use aws_smithy_eventstream::frame::write_message_to; - use std::time::{SystemTime, UNIX_EPOCH}; - // Build the inner event message - let inner_event = build_event(event_type); - - // Serialize the inner message to bytes let mut inner_bytes = Vec::new(); - write_message_to(&inner_event, &mut inner_bytes).unwrap(); - - // Create the SigV4 envelope with signature headers - let timestamp = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); + write_message_to(&inner_message, &mut inner_bytes).unwrap(); let headers = vec![ Header::new( ":chunk-signature", - HeaderValue::ByteArray(Bytes::from( - "example298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - )), + HeaderValue::ByteArray(Bytes::from(signature.to_vec())), ), Header::new( ":date", - HeaderValue::Timestamp(aws_smithy_types::DateTime::from_secs(timestamp as i64)), + HeaderValue::Timestamp(aws_smithy_types::DateTime::from_secs(timestamp_secs)), ), ]; Message::new_from_parts(headers, Bytes::from(inner_bytes)) } +fn build_sigv4_signed_event_with_signature(event_type: &str, signature: &[u8]) -> Message { + use std::time::{SystemTime, UNIX_EPOCH}; + + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + sign_message(build_event(event_type), signature, timestamp as i64) +} + +fn build_sigv4_signed_initial_data(data: &str, signature: &[u8], timestamp_secs: i64) -> Message { + sign_message(build_initial_data_message(data), signature, timestamp_secs) +} + fn get_event_type(msg: &Message) -> &str { msg.headers() .iter() @@ -396,45 +454,6 @@ fn get_event_type(msg: &Message) -> &str { .as_str() } -#[tokio::test] -async fn test_streaming_operation_with_initial_request() { - let mut harness = TestHarness::new("StreamingOperation").await; - - // if we send an initial request it should work - harness.send_initial_request().await; - harness.send_event("A").await; - - let resp = harness.expect_message().await; - assert_eq!(get_event_type(&resp), "A"); - - // Check that initial-response was received - assert!(harness.initial_response.is_some()); - assert_eq!( - get_event_type(harness.initial_response.as_ref().unwrap()), - "initial-response" - ); - - assert_eq!( - harness.server.streaming_operation_events(), - vec![Events::A(Event {})] - ); -} - -#[tokio::test] -async fn test_streaming_operation_without_initial_request() { - let mut harness = TestHarness::new("StreamingOperation").await; - - // BUT: if we don't send an initial request, it should also work - harness.send_event("A").await; - - let resp = harness.expect_message().await; - assert_eq!(get_event_type(&resp), "A"); - assert_eq!( - harness.server.streaming_operation_events(), - vec![Events::A(Event {})] - ); -} - #[tokio::test] async fn test_streaming_operation_with_initial_data() { let mut harness = TestHarness::new("StreamingOperationWithInitialData").await; @@ -447,7 +466,10 @@ async fn test_streaming_operation_with_initial_data() { assert_eq!( harness .server - .streaming_operation_with_initial_data_events(), + .streaming_operation_with_initial_data_events() + .into_iter() + .map(|e| e.message) + .collect::>(), vec![Events::A(Event {})] ); // verify that we parsed the initial data properly @@ -471,29 +493,14 @@ async fn test_streaming_operation_with_initial_data_missing() { assert_eq!( harness .server - .streaming_operation_with_initial_data_events(), + .streaming_operation_with_initial_data_events() + .into_iter() + .map(|e| e.message) + .collect::>(), vec![] ); } -/// Test that the server can handle SigV4 signed event stream messages. -/// The client wraps the actual event in a SigV4 envelope with signature headers. -#[tokio::test] -async fn test_sigv4_signed_event_stream() { - let mut harness = TestHarness::new("StreamingOperation").await; - - // Send a SigV4 signed event - the inner message is wrapped in an envelope - let signed_event = build_sigv4_signed_event("A"); - harness.client.send(signed_event).await.unwrap(); - - let resp = harness.expect_message().await; - assert_eq!(get_event_type(&resp), "A"); - assert_eq!( - harness.server.streaming_operation_events(), - vec![Events::A(Event {})] - ); -} - /// Test that when alwaysSendEventStreamInitialResponse is disabled, no initial-response is sent #[tokio::test] async fn test_server_no_initial_response_when_disabled() { @@ -612,9 +619,234 @@ async fn test_streaming_operation_with_optional_data() { assert_eq!( harness .server - .streaming_operation_with_optional_data_events(), + .streaming_operation_with_optional_data_events() + .into_iter() + .map(|e| e.message) + .collect::>(), vec![Events::A(Event {})] ); // Verify optional data was not provided assert_eq!(harness.server.optional_data(), None); } + +/// Test that SigV4-framed initial-request messages are properly handled. +/// This verifies the fix for issue #4397 where try_recv_initial_request +/// can now see inside the SigV4 envelope to detect the initial-request event type. +#[tokio::test] +async fn test_sigv4_framed_initial_request_with_data() { + let _logs = show_filtered_test_logs( + "aws_smithy_http_server=trace,hyper_util=debug,rpcv2cbor_extras=trace", + ); + let mut harness = TestHarness::new("StreamingOperationWithInitialData").await; + + // Send a SigV4-framed initial-request with data + let signed_initial_request = build_sigv4_signed_initial_data( + "test-data", + b"example298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + 1700000000, + ); + harness.client.send(signed_initial_request).await.unwrap(); + + harness.send_event("A").await; + + // The server should now properly extract the initial-request from the SigV4 envelope + let resp = harness.expect_message().await; + assert_eq!(get_event_type(&resp), "A"); + + // Verify the server received and parsed the initial data from inside the SigV4 envelope + assert_eq!(harness.server.initial_data(), Some("test-data".to_string())); + assert_eq!( + harness.server.initial_signature(), + Some(b"example298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_vec()) + ); +} + +#[derive(Debug, Clone, Copy)] +enum InitialMessage { + None, + Unsigned, + Signed, +} + +#[derive(Debug, Clone)] +struct EventStreamTestCase { + initial: InitialMessage, + events_signed: Vec, +} + +/// Comprehensive test matrix for SigV4 event stream combinations +#[rstest::rstest] +#[case::no_initial_unsigned_events(EventStreamTestCase { initial: InitialMessage::None, events_signed: vec![false, false] })] +#[case::no_initial_signed_events(EventStreamTestCase { initial: InitialMessage::None, events_signed: vec![true, true] })] +#[case::no_initial_mixed_events(EventStreamTestCase { initial: InitialMessage::None, events_signed: vec![false, true] })] +#[case::unsigned_initial_unsigned_events(EventStreamTestCase { initial: InitialMessage::Unsigned, events_signed: vec![false, false] })] +#[case::unsigned_initial_signed_events(EventStreamTestCase { initial: InitialMessage::Unsigned, events_signed: vec![true, true] })] +#[case::unsigned_initial_mixed_events(EventStreamTestCase { initial: InitialMessage::Unsigned, events_signed: vec![false, true] })] +#[case::signed_initial_unsigned_events(EventStreamTestCase { initial: InitialMessage::Signed, events_signed: vec![false, false] })] +#[case::signed_initial_signed_events(EventStreamTestCase { initial: InitialMessage::Signed, events_signed: vec![true, true] })] +#[case::signed_initial_mixed_events(EventStreamTestCase { initial: InitialMessage::Signed, events_signed: vec![false, true] })] +#[case::no_events(EventStreamTestCase { initial: InitialMessage::None, events_signed: vec![] })] +#[case::many_signed_events(EventStreamTestCase { initial: InitialMessage::Signed, events_signed: vec![true; 100] })] +#[case::many_unsigned_events(EventStreamTestCase { initial: InitialMessage::None, events_signed: vec![false; 100] })] +#[tokio::test] +async fn test_sigv4_event_stream_matrix(#[case] test_case: EventStreamTestCase) { + let mut harness = TestHarness::new("StreamingOperation").await; + + // Send initial message if specified + match test_case.initial { + InitialMessage::None => {} + InitialMessage::Unsigned => { + harness.client.send(build_initial_request()).await.unwrap(); + } + InitialMessage::Signed => { + let signed_initial = sign_message(build_initial_request(), b"initial-sig", 1700000000); + harness.client.send(signed_initial).await.unwrap(); + } + } + + // Send events + for (i, &signed) in test_case.events_signed.iter().enumerate() { + let event_type = if i % 2 == 0 { "A" } else { "B" }; + if signed { + let sig = format!("sig-event-{i}"); + let signed_event = build_sigv4_signed_event_with_signature(event_type, sig.as_bytes()); + harness.client.send(signed_event).await.unwrap(); + } else { + harness.send_event(event_type).await; + } + } + + // Receive response (only if we sent events) + if !test_case.events_signed.is_empty() { + let resp = harness.expect_message().await; + assert_eq!(get_event_type(&resp), "A"); + } + + // Verify events + let events = harness.server.streaming_operation_events(); + assert_eq!(events.len(), test_case.events_signed.len()); + + for (i, &signed) in test_case.events_signed.iter().enumerate() { + let expected_event = if i % 2 == 0 { + Events::A(Event {}) + } else { + Events::B(Event {}) + }; + assert_eq!(events[i].message, expected_event); + + if signed { + assert!( + events[i].signature.is_some(), + "Event {i} should have signature" + ); + let expected_sig = format!("sig-event-{i}"); + assert_eq!( + events[i].signature.as_ref().unwrap().chunk_signature, + expected_sig.as_bytes() + ); + } else { + assert!( + events[i].signature.is_none(), + "Event {i} should not have signature" + ); + } + } + + // Verify initial signature + match test_case.initial { + InitialMessage::Signed => { + assert_eq!( + harness.server.streaming_operation_initial_signature(), + Some(b"initial-sig".to_vec()) + ); + } + InitialMessage::None | InitialMessage::Unsigned => { + assert_eq!(harness.server.streaming_operation_initial_signature(), None); + } + } +} + +/// Test signed initial data with signed events +#[tokio::test] +async fn test_sigv4_signed_initial_data_with_signed_events() { + let mut harness = TestHarness::new("StreamingOperationWithInitialData").await; + + // Send signed initial data + let signed_initial = + build_sigv4_signed_initial_data("test-data", b"sig-initial-data", 1700000000); + harness.client.send(signed_initial).await.unwrap(); + + // Send signed events + let signed_event_a = build_sigv4_signed_event_with_signature("A", b"sig-event-A"); + harness.client.send(signed_event_a).await.unwrap(); + + let signed_event_b = build_sigv4_signed_event_with_signature("B", b"sig-event-B"); + harness.client.send(signed_event_b).await.unwrap(); + + let resp = harness.expect_message().await; + assert_eq!(get_event_type(&resp), "A"); + + // Verify initial data was received + assert_eq!(harness.server.initial_data(), Some("test-data".to_string())); + + // Verify initial signature + assert_eq!( + harness.server.initial_signature(), + Some(b"sig-initial-data".to_vec()) + ); + + // Verify events with signatures + let events = harness + .server + .streaming_operation_with_initial_data_events(); + assert_eq!(events.len(), 2); + + assert_eq!(events[0].message, Events::A(Event {})); + assert_eq!( + events[0].signature.as_ref().unwrap().chunk_signature, + b"sig-event-A" + ); + + assert_eq!(events[1].message, Events::B(Event {})); + assert_eq!( + events[1].signature.as_ref().unwrap().chunk_signature, + b"sig-event-B" + ); +} + +/// Test that timestamps are preserved in signatures +#[tokio::test] +async fn test_sigv4_timestamp_preservation() { + let mut harness = TestHarness::new("StreamingOperation").await; + + // Send events with specific timestamps + let timestamp1 = 1700000000i64; + let timestamp2 = 1700000100i64; + + let event1 = sign_message(build_event("A"), b"sig-1", timestamp1); + harness.client.send(event1).await.unwrap(); + + let event2 = sign_message(build_event("B"), b"sig-2", timestamp2); + harness.client.send(event2).await.unwrap(); + + let resp = harness.expect_message().await; + assert_eq!(get_event_type(&resp), "A"); + + let events = harness.server.streaming_operation_events(); + assert_eq!(events.len(), 2); + + // Verify timestamps are preserved + use std::time::UNIX_EPOCH; + + let expected_time1 = UNIX_EPOCH + std::time::Duration::from_secs(timestamp1 as u64); + assert_eq!( + events[0].signature.as_ref().unwrap().timestamp, + expected_time1 + ); + + let expected_time2 = UNIX_EPOCH + std::time::Duration::from_secs(timestamp2 as u64); + assert_eq!( + events[1].signature.as_ref().unwrap().timestamp, + expected_time2 + ); +} diff --git a/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt b/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt index 53763211c73..dea89fbc07e 100644 --- a/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt +++ b/codegen-server/codegen-server-typescript/src/main/kotlin/software/amazon/smithy/rust/codegen/server/typescript/smithy/TsServerCodegenVisitor.kt @@ -62,7 +62,7 @@ class TsServerCodegenVisitor( ServerProtocolLoader( codegenDecorator.protocols( service.id, - ServerProtocolLoader.defaultProtocols(), + ServerProtocolLoader.DefaultProtocols, ), ) .protocolFor(context.model, service) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 7fd8d76b32c..121001a0da5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -139,12 +139,7 @@ open class ServerCodegenVisitor( ServerProtocolLoader( codegenDecorator.protocols( service.id, - ServerProtocolLoader.defaultProtocols { it -> - codegenDecorator.httpCustomizations( - serverSymbolProviders.symbolProvider, - it, - ) - }, + ServerProtocolLoader.DefaultProtocols, ), ) .protocolFor(context.model, service) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt index f6a617bd38a..86c53d1ceda 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -34,6 +34,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.orNull import java.util.logging.Level +import java.util.stream.Collectors private sealed class UnsupportedConstraintMessageKind { private val constraintTraitsUberIssue = "https://github.com/smithy-lang/smithy-rs/issues/1401" @@ -330,7 +331,7 @@ fun validateModelHasAtMostOneValidationException( model .shapes() .filter { it.hasTrait(ValidationExceptionTrait.ID) && it.isReachableFromOperationErrors(model) } - .toList() + .collect(Collectors.toList()) val messages = mutableListOf() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt index 0cf966749e4..e1ba675183a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamDecorator.kt @@ -11,18 +11,8 @@ import software.amazon.smithy.model.knowledge.ServiceIndex import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingSection -import software.amazon.smithy.rust.codegen.core.smithy.mapRustType -import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.isEventStream @@ -36,59 +26,28 @@ class SigV4EventStreamDecorator : ServerCodegenDecorator { override val name: String = "SigV4EventStreamDecorator" override val order: Byte = 0 - override fun httpCustomizations( - symbolProvider: RustSymbolProvider, - protocol: ShapeId, - ): List { - return listOf(SigV4EventStreamCustomization(symbolProvider)) - } - override fun symbolProvider(base: RustSymbolProvider): RustSymbolProvider { - // We need access to the service shape to check for SigV4 trait, but the base interface doesn't provide it. - // For now, we'll wrap all event streams and let the runtime code handle the detection. - return SigV4EventStreamSymbolProvider(base) + if (base.usesSigAuth()) { + return SigV4EventStreamSymbolProvider(base) + } else { + return base + } } } internal fun RustSymbolProvider.usesSigAuth(): Boolean = ServiceIndex.of(model).getAuthSchemes(moduleProviderContext.serviceShape!!).containsKey(SigV4Trait.ID) -// Goes from `T` to `SignedEvent` -fun wrapInSignedEvent( - inner: Symbol, - runtimeConfig: RuntimeConfig, -) = inner.mapRustType { - RustType.Application( - SigV4EventStreamSupportStructures.signedEvent(runtimeConfig).toSymbol().rustType(), - listOf(inner.rustType()), - ) -} - -// Goes from `E` to `SignedEventError` -fun wrapInSignedEventError( - inner: Symbol, - runtimeConfig: RuntimeConfig, -) = inner.mapRustType { - RustType.Application( - SigV4EventStreamSupportStructures.signedEventError(runtimeConfig).toSymbol().rustType(), - listOf(inner.rustType()), - ) -} - /** * Symbol provider wrapper that modifies event stream types to support SigV4 signed messages. */ class SigV4EventStreamSymbolProvider( base: RustSymbolProvider, ) : WrappingSymbolProvider(base) { - private val serviceIsSigv4 = base.usesSigAuth() private val runtimeConfig = base.config.runtimeConfig override fun toSymbol(shape: Shape): Symbol { val baseSymbol = super.toSymbol(shape) - if (!serviceIsSigv4) { - return baseSymbol - } // We only want to wrap with Event Stream types when dealing with member shapes if (shape is MemberShape && shape.isEventStream(model)) { // Determine if the member has a container that is a synthetic input or output @@ -109,26 +68,3 @@ class SigV4EventStreamSymbolProvider( return baseSymbol } } - -class SigV4EventStreamCustomization(private val symbolProvider: RustSymbolProvider) : HttpBindingCustomization() { - override fun section(section: HttpBindingSection): Writable = - writable { - when (section) { - is HttpBindingSection.BeforeCreatingEventStreamReceiver -> { - // Check if this service uses SigV4 auth - if (symbolProvider.usesSigAuth()) { - val codegenScope = - SigV4EventStreamSupportStructures.codegenScope(symbolProvider.config.runtimeConfig) - rustTemplate( - """ - let ${section.unmarshallerVariableName} = #{SigV4Unmarshaller}::new(${section.unmarshallerVariableName}); - """, - *codegenScope, - ) - } - } - - else -> {} - } - } -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt index 81f67f1dd41..00ece53a410 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructures.kt @@ -17,7 +17,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.PANIC object SigV4EventStreamSupportStructures { - private val supportModule = RustModule.private("sigv4_event_stream") + internal val supportModule = + RustModule.public( + "sigv4_event_stream", + documentationOverride = "Support structures for SigV4 signed event streams", + ) fun codegenScope(runtimeConfig: RuntimeConfig) = arrayOf( @@ -25,42 +29,30 @@ object SigV4EventStreamSupportStructures { "ExtractionError" to extractionError(runtimeConfig), "SignedEventError" to signedEventError(runtimeConfig), "SignedEvent" to signedEvent(runtimeConfig), + "SigV4Receiver" to sigV4Receiver(runtimeConfig), "SigV4Unmarshaller" to sigV4Unmarshaller(runtimeConfig), "extract_signed_message" to extractSignedMessage(runtimeConfig), ) /** - * Wraps an event stream Receiver type to handle SigV4 signed messages. - * Transforms: Receiver -> Receiver, SignedEventError> + * Wraps an event stream Receiver type with SigV4Receiver. + * Transforms: Receiver -> SigV4Receiver */ fun wrapInEventStreamSigV4( symbol: Symbol, runtimeConfig: RuntimeConfig, ): Symbol { - val signedEvent = signedEvent(runtimeConfig) - val signedEventError = signedEventError(runtimeConfig) - return symbol.mapRustType(signedEvent, signedEventError) { rustType -> + val sigV4Receiver = sigV4Receiver(runtimeConfig) + return symbol.mapRustType(sigV4Receiver) { rustType -> // Expect Application(Receiver, [T, E]) if (rustType is RustType.Application && rustType.name == "Receiver" && rustType.args.size == 2) { val eventType = rustType.args[0] val errorType = rustType.args[1] - // Create SignedEvent and SignedEventError - val wrappedEventType = - RustType.Application( - signedEvent.toSymbol().rustType(), - listOf(eventType), - ) - val wrappedErrorType = - RustType.Application( - signedEventError.toSymbol().rustType(), - listOf(errorType), - ) - - // Create new Receiver, SignedEventError> + // Create SigV4Receiver RustType.Application( - rustType.type, - listOf(wrappedEventType, wrappedErrorType), + sigV4Receiver.toSymbol().rustType(), + listOf(eventType, errorType), ) } else { PANIC("Called wrap in EventStreamSigV4 on ${symbol.rustType()} which was not an event stream receiver") @@ -74,7 +66,7 @@ object SigV4EventStreamSupportStructures { """ /// Information extracted from a signed event stream message ##[non_exhaustive] - ##[derive(Debug, Clone)] + ##[derive(Debug, Clone, PartialEq)] pub struct SignatureInfo { /// The chunk signature bytes from the `:chunk-signature` header pub chunk_signature: Vec, @@ -103,8 +95,35 @@ object SigV4EventStreamSupportStructures { ##[non_exhaustive] InvalidTimestamp, } + + impl #{Display} for ExtractionError { + fn fmt(&self, f: &mut #{Formatter}<'_>) -> #{fmt_Result} { + match self { + ExtractionError::InvalidPayload { error } => { + write!(f, "invalid payload: {}", error) + } + ExtractionError::InvalidTimestamp => { + write!(f, "invalid or missing timestamp header") + } + } + } + } + + impl #{Error} for ExtractionError { + fn source(&self) -> #{Option}<&(dyn #{Error} + 'static)> { + match self { + ExtractionError::InvalidPayload { error } => #{Some}(error), + ExtractionError::InvalidTimestamp => #{None}, + } + } + } """, "EventStreamError" to CargoDependency.smithyEventStream(runtimeConfig).toType().resolve("error::Error"), + "Display" to RuntimeType.Display, + "Formatter" to RuntimeType.std.resolve("fmt::Formatter"), + "fmt_Result" to RuntimeType.std.resolve("fmt::Result"), + "Error" to RuntimeType.StdError, + *RuntimeType.preludeScope, ) } @@ -136,7 +155,7 @@ object SigV4EventStreamSupportStructures { rustTemplate( """ /// Wrapper for event stream messages that may be signed - ##[derive(Debug)] + ##[derive(Debug, Clone)] pub struct SignedEvent { /// The actual event message pub message: T, @@ -175,41 +194,41 @@ object SigV4EventStreamSupportStructures { fn unmarshall(&self, message: &#{Message}) -> #{Result}<#{UnmarshalledMessage}, #{EventStreamError}> { // First, try to extract the signed message match #{extract_signed_message}(message) { - Ok(MaybeSignedMessage::Signed { message: inner_message, signature }) => { + #{Ok}(MaybeSignedMessage::Signed { message: inner_message, signature }) => { // Process the inner message with the base unmarshaller match self.inner.unmarshall(&inner_message) { - Ok(unmarshalled) => match unmarshalled { + #{Ok}(unmarshalled) => match unmarshalled { #{UnmarshalledMessage}::Event(event) => { - Ok(#{UnmarshalledMessage}::Event(#{SignedEvent} { + #{Ok}(#{UnmarshalledMessage}::Event(#{SignedEvent} { message: event, - signature: Some(signature), + signature: #{Some}(signature), })) } #{UnmarshalledMessage}::Error(err) => { - Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) + #{Ok}(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) } }, - Err(err) => Err(err), + #{Err}(err) => #{Err}(err), } } - Ok(MaybeSignedMessage::Unsigned) => { + #{Ok}(MaybeSignedMessage::Unsigned) => { // Process unsigned message directly match self.inner.unmarshall(message) { - Ok(unmarshalled) => match unmarshalled { + #{Ok}(unmarshalled) => match unmarshalled { #{UnmarshalledMessage}::Event(event) => { - Ok(#{UnmarshalledMessage}::Event(#{SignedEvent} { + #{Ok}(#{UnmarshalledMessage}::Event(#{SignedEvent} { message: event, - signature: None, + signature: #{None}, })) } #{UnmarshalledMessage}::Error(err) => { - Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) + #{Ok}(#{UnmarshalledMessage}::Error(#{SignedEventError}::Event(err))) } }, - Err(err) => Err(err), + #{Err}(err) => #{Err}(err), } } - Err(extraction_err) => Ok(#{UnmarshalledMessage}::Error(#{SignedEventError}::InvalidSignedEvent(extraction_err))), + #{Err}(extraction_err) => #{Ok}(#{UnmarshalledMessage}::Error(#{SignedEventError}::InvalidSignedEvent(extraction_err))), } } } @@ -229,6 +248,95 @@ object SigV4EventStreamSupportStructures { ) } + private fun sigV4Receiver(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType.forInlineFun("SigV4Receiver", supportModule) { + rustTemplate( + """ + /// Receiver wrapper that handles SigV4 signed event stream messages + ##[derive(Debug)] + pub struct SigV4Receiver { + inner: #{Receiver}<#{SignedEvent}, #{SignedEventError}>, + initial_signature: #{Option}<#{SignatureInfo}>, + } + + impl SigV4Receiver { + pub fn new( + unmarshaller: impl #{UnmarshallMessage} + #{Send} + #{Sync} + 'static, + body: #{SdkBody}, + ) -> Self { + let sigv4_unmarshaller = #{SigV4Unmarshaller}::new(unmarshaller); + Self { + inner: #{Receiver}::new(sigv4_unmarshaller, body), + initial_signature: #{None}, + } + } + + /// Get the signature from the initial message, if it was signed + pub fn initial_signature(&self) -> #{Option}<&#{SignatureInfo}> { + self.initial_signature.as_ref() + } + + /// Try to receive an initial message of the given type. + /// Handles SigV4-wrapped messages by extracting the inner message first. + pub async fn try_recv_initial( + &mut self, + message_type: #{event_stream}::InitialMessageType, + ) -> #{Result}<#{Option}<#{Message}>, #{SdkError}<#{SignedEventError}, #{RawMessage}>> + where + E: std::error::Error + 'static, + { + let result = self + .inner + .try_recv_initial_with_preprocessor(message_type, |message| { + match #{extract_signed_message}(&message) { + #{Ok}(MaybeSignedMessage::Signed { message: inner, signature }) => { + #{Ok}((inner, #{Some}(signature))) + } + #{Ok}(MaybeSignedMessage::Unsigned) => #{Ok}((message, #{None})), + #{Err}(err) => #{Err}(#{ResponseError}::builder().raw(#{RawMessage}::Decoded(message)).source(err).build()), + } + }) + .await?; + match result { + #{Some}((message, signature)) => { + self.initial_signature = signature; + #{Ok}(#{Some}(message)) + } + #{None} => #{Ok}(#{None}), + } + } + + /// Receive the next event from the stream + /// The SigV4Unmarshaller handles unwrapping, so we just pass through + pub async fn recv(&mut self) -> #{Result}<#{Option}<#{SignedEvent}>, #{SdkError}<#{SignedEventError}, #{RawMessage}>> + where + E: std::error::Error + 'static, + { + self.inner.recv().await + } + } + """, + "Receiver" to RuntimeType.eventStreamReceiver(runtimeConfig), + "SigV4Unmarshaller" to sigV4Unmarshaller(runtimeConfig), + "event_stream" to RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "Message" to CargoDependency.smithyTypes(runtimeConfig).toType().resolve("event_stream::Message"), + "RawMessage" to CargoDependency.smithyTypes(runtimeConfig).toType().resolve("event_stream::RawMessage"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "ResponseError" to + RuntimeType.smithyRuntimeApiClient(runtimeConfig) + .resolve("client::result::ResponseError"), + "UnmarshallMessage" to + CargoDependency.smithyEventStream(runtimeConfig).toType() + .resolve("frame::UnmarshallMessage"), + "SignedEvent" to signedEvent(runtimeConfig), + "SignedEventError" to signedEventError(runtimeConfig), + "SignatureInfo" to signatureInfo(), + "extract_signed_message" to extractSignedMessage(runtimeConfig), + *RuntimeType.preludeScope, + ) + } + private fun extractSignedMessage(runtimeConfig: RuntimeConfig): RuntimeType = RuntimeType.forInlineFun("extract_signed_message", supportModule) { rustTemplate( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/UserProvidedValidationExceptionDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/UserProvidedValidationExceptionDecorator.kt index 8df32515289..09901d013ac 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/UserProvidedValidationExceptionDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/UserProvidedValidationExceptionDecorator.kt @@ -55,6 +55,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser import software.amazon.smithy.rust.codegen.server.smithy.util.isValidationFieldName import software.amazon.smithy.rust.codegen.server.smithy.util.isValidationMessage import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage +import java.util.stream.Collectors /** * Decorator for user provided validation exception codegen @@ -84,7 +85,7 @@ class UserProvidedValidationExceptionDecorator : ServerCodegenDecorator { internal fun firstStructureShapeWithValidationExceptionTrait(model: Model): StructureShape? = model .shapes(StructureShape::class.java) - .toList() + .collect(Collectors.toList()) // Defining multiple validation exceptions is unsupported. See `ValidateUnsupportedConstraints` .firstOrNull({ it.hasTrait(ValidationExceptionTrait.ID) }) @@ -328,9 +329,7 @@ class UserProvidedValidationExceptionConversionGenerator( "FieldAssignments" to fieldAssignments( "path.clone()", - """format!(${ - lengthTrait.validationErrorMessage().dq() - }, length, &path)""", + "format!(${lengthTrait.validationErrorMessage().dq()}, length, &path)", ), ) } @@ -351,9 +350,9 @@ class UserProvidedValidationExceptionConversionGenerator( "FieldAssignments" to fieldAssignments( "path.clone()", - """format!(${ + "format!(${ patternTrait.validationErrorMessage().dq() - }, &path, ${patternTrait.pattern.toString().dq()})""", + }, &path, ${patternTrait.pattern.toString().dq()})", ), ) } @@ -394,9 +393,9 @@ class UserProvidedValidationExceptionConversionGenerator( "FieldAssignments" to fieldAssignments( "path.clone()", - """format!(${ + "format!(${ blobLength.lengthTrait.validationErrorMessage().dq() - }, length, &path)""", + }, length, &path)", ), ) } @@ -519,7 +518,7 @@ class UserProvidedValidationExceptionConversionGenerator( ConstraintViolation::${it.name()} => #{ValidationExceptionField} { #{FieldAssignments} }, - """.trimIndent(), + """, *codegenScope, "FieldAssignments" to fieldAssignments( @@ -569,10 +568,10 @@ class UserProvidedValidationExceptionConversionGenerator( "FieldAssignments" to fieldAssignments( "path.clone()", - """format!(${ + "format!(${ collectionTraitInfo.lengthTrait.validationErrorMessage() .dq() - }, length, &path)""", + }, length, &path)", ), ) } @@ -588,10 +587,15 @@ class UserProvidedValidationExceptionConversionGenerator( "FieldAssignments" to fieldAssignments( "path.clone()", - """format!(${ - collectionTraitInfo.uniqueItemsTrait.validationErrorMessage() - .dq() - }, &duplicate_indices, &path)""", + """ + format!( + ${ + collectionTraitInfo.uniqueItemsTrait.validationErrorMessage().dq() + }, + &duplicate_indices, + &path + ) + """, ), ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index d82d3ad65fb..6e23bbd6787 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -10,10 +10,8 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings @@ -34,11 +32,6 @@ interface ServerCodegenDecorator : CoreCodegenDecorator = emptyList() - fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator? = null @@ -102,11 +95,6 @@ class CombinedServerCodegenDecorator(decorators: List) : decorator.protocols(serviceId, protocolMap) } - override fun httpCustomizations( - symbolProvider: RustSymbolProvider, - protocol: ShapeId, - ): List = orderedDecorators.flatMap { it.httpCustomizations(symbolProvider, protocol) } - override fun validationExceptionConversion( codegenContext: ServerCodegenContext, ): ValidationExceptionConversionGenerator = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index 98434672d52..53a3f406891 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -74,7 +74,6 @@ class ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstr is HttpBindingSection.BeforeRenderingHeaderValue, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, - is HttpBindingSection.BeforeCreatingEventStreamReceiver, -> emptySection } } @@ -109,7 +108,6 @@ class ServerResponseBeforeRenderingHeadersHttpBindingCustomization(val codegenCo is HttpBindingSection.BeforeIteratingOverMapShapeBoundWithHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoAHashMapOfHttpPrefixHeaders, is HttpBindingSection.AfterDeserializingIntoADateTimeOfHttpHeaders, - is HttpBindingSection.BeforeCreatingEventStreamReceiver, -> emptySection } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ea875699843..8d5e3748672 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,13 +9,11 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap @@ -60,44 +58,39 @@ class StreamPayloadSerializerCustomization : ServerHttpBoundProtocolCustomizatio class ServerProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { - fun defaultProtocols( - httpBindingCustomizations: (ShapeId) -> List = { _ -> listOf() }, - ) = mapOf( - RestJson1Trait.ID to - ServerRestJsonFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - additionalHttpBindingCustomizations = httpBindingCustomizations(RestJson1Trait.ID), - ), - RestXmlTrait.ID to - ServerRestXmlFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - ), - AwsJson1_0Trait.ID to - ServerAwsJsonFactory( - AwsJsonVersion.Json10, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - additionalHttpBindingCustomizations = httpBindingCustomizations(AwsJson1_0Trait.ID), - ), - AwsJson1_1Trait.ID to - ServerAwsJsonFactory( - AwsJsonVersion.Json11, - additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), - additionalHttpBindingCustomizations = httpBindingCustomizations(AwsJson1_1Trait.ID), - ), - Rpcv2CborTrait.ID to - ServerRpcV2CborFactory( - additionalServerHttpBoundProtocolCustomizations = - listOf( - StreamPayloadSerializerCustomization(), - ), - additionalHttpBindingCustomizations = httpBindingCustomizations(Rpcv2CborTrait.ID), - ), - ) + val DefaultProtocols = + mapOf( + RestJson1Trait.ID to + ServerRestJsonFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + RestXmlTrait.ID to + ServerRestXmlFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + AwsJson1_0Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + AwsJson1_1Trait.ID to + ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + Rpcv2CborTrait.ID to + ServerRpcV2CborFactory( + additionalServerHttpBoundProtocolCustomizations = + listOf( + StreamPayloadSerializerCustomization(), + ), + ), + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index 840d82720db..a03dfbdd4e3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -135,7 +135,7 @@ fun serverTestCodegenContext( fun loadServerProtocol(model: Model): ServerProtocol { val codegenContext = serverTestCodegenContext(model) val (_, protocolGeneratorFactory) = - ServerProtocolLoader(ServerProtocolLoader.defaultProtocols()).protocolFor(model, codegenContext.serviceShape) + ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols).protocolFor(model, codegenContext.serviceShape) return protocolGeneratorFactory.buildProtocolGenerator(codegenContext).protocol } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt index 2f09c18ffcc..059092889a9 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SigV4EventStreamSupportStructuresTest.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.customizations import org.junit.jupiter.api.Test -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -19,7 +18,7 @@ class SigV4EventStreamSupportStructuresTest { @Test fun `support structures compile`() { val project = TestWorkspace.testProject() - project.withModule(RustModule.private("sigv4_event_stream")) { + project.lib { val codegenScope = SigV4EventStreamSupportStructures.codegenScope(runtimeConfig) // Generate the support structures - RuntimeType.forInlineFun automatically generates the code diff --git a/rust-runtime/Cargo.lock b/rust-runtime/Cargo.lock index 8dcab055eeb..8b9a942a32f 100644 --- a/rust-runtime/Cargo.lock +++ b/rust-runtime/Cargo.lock @@ -396,7 +396,7 @@ version = "0.2.1" [[package]] name = "aws-smithy-http" -version = "0.62.5" +version = "0.62.6" dependencies = [ "async-stream", "aws-smithy-eventstream", @@ -539,7 +539,7 @@ dependencies = [ [[package]] name = "aws-smithy-mocks" -version = "0.2.0" +version = "0.2.1" dependencies = [ "aws-smithy-async", "aws-smithy-http-client", diff --git a/rust-runtime/aws-smithy-http/Cargo.toml b/rust-runtime/aws-smithy-http/Cargo.toml index 2df48b85b7c..6861e5fa506 100644 --- a/rust-runtime/aws-smithy-http/Cargo.toml +++ b/rust-runtime/aws-smithy-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-http" -version = "0.62.5" +version = "0.62.6" authors = [ "AWS Rust SDK Team ", "Russell Cohen ", diff --git a/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs index 733b8a6610b..07351f2ba98 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs @@ -6,7 +6,7 @@ use aws_smithy_eventstream::frame::{ DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage, }; -use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError}; +use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError}; use aws_smithy_types::body::SdkBody; use aws_smithy_types::event_stream::{Message, RawMessage}; use bytes::Buf; @@ -229,8 +229,31 @@ impl Receiver { &mut self, message_type: InitialMessageType, ) -> Result, SdkError> { + self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ()))) + .await + .map(|opt| opt.map(|(msg, _)| msg)) + } + + /// Tries to receive the initial response message with preprocessing support. + /// + /// The preprocessor function can transform the raw message (e.g., unwrap envelopes) + /// and return metadata alongside the transformed message. If the transformed message + /// matches the expected `message_type`, both the message and metadata are returned. + /// Otherwise, the transformed message is buffered and `Ok(None)` is returned. + #[doc(hidden)] + pub async fn try_recv_initial_with_preprocessor( + &mut self, + message_type: InitialMessageType, + preprocessor: F, + ) -> Result, SdkError> + where + F: FnOnce(Message) -> Result<(Message, M), ResponseError>, + { if let Some(message) = self.next_message().await? { - if let Some(event_type) = message + let (processed_message, metadata) = + preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?; + + if let Some(event_type) = processed_message .headers() .iter() .find(|h| h.name().as_str() == ":event-type") @@ -241,10 +264,10 @@ impl Receiver { .map(|s| s.as_str() == message_type.as_str()) .unwrap_or(false) { - return Ok(Some(message)); + return Ok(Some((processed_message, metadata))); } } - // Buffer the message so that it can be returned by the next call to `recv()` + // Buffer the processed message so that it can be returned by the next call to `recv()` self.buffered_message = Some(message); } Ok(None) diff --git a/tools/ci-cdk/canary-runner/Cargo.lock b/tools/ci-cdk/canary-runner/Cargo.lock index a145ab9d446..c27f946f053 100644 --- a/tools/ci-cdk/canary-runner/Cargo.lock +++ b/tools/ci-cdk/canary-runner/Cargo.lock @@ -2650,7 +2650,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "smithy-rs-tool-common" -version = "0.1.0" +version = "0.1.1" dependencies = [ "anyhow", "async-trait",