Skip to content

Commit e674c93

Browse files
Enable more integration tests (#8)
* Add ListObject * Fix bad discovery in workflow case * Fix error conversions, no need for the anyhow feature anymore * More tests enabled
1 parent 5036b7a commit e674c93

File tree

11 files changed

+177
-19
lines changed

11 files changed

+177
-19
lines changed

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ license = "MIT"
77
repository = "https://github.com/restatedev/sdk-rust"
88

99
[features]
10-
default = ["http", "anyhow"]
10+
default = ["http"]
1111
http = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"]
1212

1313
[dependencies]
14-
anyhow = {version = "1.0", optional = true}
1514
bytes = "1.6.1"
1615
futures = "0.3"
1716
http-body-util = { version = "0.1", optional = true }

macros/src/gen.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ impl<'a> ServiceGenerator<'a> {
178178

179179
let service_literal = Literal::string(restate_name);
180180

181-
let service_ty = match service_ty {
181+
let service_ty_token = match service_ty {
182182
ServiceType::Service => quote! { ::restate_sdk::discovery::ServiceType::Service },
183183
ServiceType::Object => {
184184
quote! { ::restate_sdk::discovery::ServiceType::VirtualObject }
@@ -191,6 +191,8 @@ impl<'a> ServiceGenerator<'a> {
191191

192192
let handler_ty = if handler.is_shared {
193193
quote! { Some(::restate_sdk::discovery::HandlerType::Shared) }
194+
} else if *service_ty == ServiceType::Workflow {
195+
quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) }
194196
} else {
195197
// Macro has same defaulting rules of the discovery manifest
196198
quote! { None }
@@ -212,7 +214,7 @@ impl<'a> ServiceGenerator<'a> {
212214
{
213215
fn discover() -> ::restate_sdk::discovery::Service {
214216
::restate_sdk::discovery::Service {
215-
ty: #service_ty,
217+
ty: #service_ty_token,
216218
name: ::restate_sdk::discovery::ServiceName::try_from(#service_literal.to_string())
217219
.expect("Service name valid"),
218220
handlers: vec![#( #handlers ),*],

src/errors.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use restate_sdk_shared_core::Failure;
22
use std::error::Error as StdError;
33
use std::fmt;
4-
use thiserror::__private::AsDynError;
54

65
#[derive(Debug)]
76
pub(crate) enum HandlerErrorInner {
@@ -23,7 +22,7 @@ impl fmt::Display for HandlerErrorInner {
2322
impl StdError for HandlerErrorInner {
2423
fn source(&self) -> Option<&(dyn StdError + 'static)> {
2524
match self {
26-
HandlerErrorInner::Retryable(e) => Some(e.as_dyn_error()),
25+
HandlerErrorInner::Retryable(e) => Some(e.as_ref()),
2726
HandlerErrorInner::Terminal(e) => Some(e),
2827
}
2928
}
@@ -32,16 +31,9 @@ impl StdError for HandlerErrorInner {
3231
#[derive(Debug)]
3332
pub struct HandlerError(pub(crate) HandlerErrorInner);
3433

35-
impl HandlerError {
36-
#[cfg(feature = "anyhow")]
37-
pub fn from_anyhow(err: anyhow::Error) -> Self {
38-
Self(HandlerErrorInner::Retryable(err.into()))
39-
}
40-
}
41-
42-
impl<E: StdError + Send + Sync + 'static> From<E> for HandlerError {
34+
impl<E: Into<Box<dyn StdError + Send + Sync + 'static>>> From<E> for HandlerError {
4335
fn from(value: E) -> Self {
44-
Self(HandlerErrorInner::Retryable(Box::new(value)))
36+
Self(HandlerErrorInner::Retryable(value.into()))
4537
}
4638
}
4739

src/serde.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,9 @@ where
174174
serde_json::from_slice(bytes).map(Json)
175175
}
176176
}
177+
178+
impl<T: Default> Default for Json<T> {
179+
fn default() -> Self {
180+
Self(T::default())
181+
}
182+
}

test-services/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Test services
2+
3+
To build (from the repo root):
4+
5+
```shell
6+
$ podman build -f test-services/Dockerfile -t restatedev/rust-test-services .
7+
```
8+
9+
To run (download the [sdk-test-suite](https://github.com/restatedev/sdk-test-suite) first):
10+
11+
```shell
12+
$ java -jar restate-sdk-test-suite.jar run restatedev/rust-test-services
13+
```

test-services/exclusions.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ exclusions:
88
- "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation"
99
- "dev.restate.sdktesting.tests.UpgradeWithNewInvocation"
1010
- "dev.restate.sdktesting.tests.UserErrors"
11-
- "dev.restate.sdktesting.tests.WorkflowAPI"
1211
"default":
1312
- "dev.restate.sdktesting.tests.AwaitTimeout"
1413
- "dev.restate.sdktesting.tests.CallOrdering"
@@ -21,7 +20,6 @@ exclusions:
2120
- "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation"
2221
- "dev.restate.sdktesting.tests.UpgradeWithNewInvocation"
2322
- "dev.restate.sdktesting.tests.UserErrors"
24-
- "dev.restate.sdktesting.tests.WorkflowAPI"
2523
"persistedTimers":
2624
- "dev.restate.sdktesting.tests.Sleep"
2725
"singleThreadSinglePartition":
@@ -36,4 +34,3 @@ exclusions:
3634
- "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation"
3735
- "dev.restate.sdktesting.tests.UpgradeWithNewInvocation"
3836
- "dev.restate.sdktesting.tests.UserErrors"
39-
- "dev.restate.sdktesting.tests.WorkflowAPI"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use restate_sdk::prelude::*;
2+
3+
#[restate_sdk::object]
4+
#[name = "AwakeableHolder"]
5+
pub(crate) trait AwakeableHolder {
6+
#[name = "hold"]
7+
async fn hold(id: String) -> HandlerResult<()>;
8+
#[name = "hasAwakeable"]
9+
#[shared]
10+
async fn has_awakeable() -> HandlerResult<bool>;
11+
#[name = "unlock"]
12+
async fn unlock(payload: String) -> HandlerResult<()>;
13+
}
14+
15+
pub(crate) struct AwakeableHolderImpl;
16+
17+
const ID: &str = "id";
18+
19+
impl AwakeableHolder for AwakeableHolderImpl {
20+
async fn hold(&self, context: ObjectContext<'_>, id: String) -> HandlerResult<()> {
21+
context.set(ID, id);
22+
Ok(())
23+
}
24+
25+
async fn has_awakeable(&self, context: SharedObjectContext<'_>) -> HandlerResult<bool> {
26+
Ok(context.get::<String>(ID).await?.is_some())
27+
}
28+
29+
async fn unlock(&self, context: ObjectContext<'_>, payload: String) -> HandlerResult<()> {
30+
let k: String = context.get(ID).await?.ok_or_else(|| {
31+
TerminalError::new(format!(
32+
"No awakeable stored for awakeable holder {}",
33+
context.key()
34+
))
35+
})?;
36+
context.resolve_awakeable(&k, payload);
37+
Ok(())
38+
}
39+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use restate_sdk::prelude::*;
2+
3+
#[restate_sdk::workflow]
4+
#[name = "BlockAndWaitWorkflow"]
5+
pub(crate) trait BlockAndWaitWorkflow {
6+
#[name = "run"]
7+
async fn run(input: String) -> HandlerResult<String>;
8+
#[name = "unblock"]
9+
#[shared]
10+
async fn unblock(output: String) -> HandlerResult<()>;
11+
#[name = "getState"]
12+
#[shared]
13+
async fn get_state() -> HandlerResult<Json<Option<String>>>;
14+
}
15+
16+
pub(crate) struct BlockAndWaitWorkflowImpl;
17+
18+
const MY_PROMISE: &str = "my-promise";
19+
const MY_STATE: &str = "my-state";
20+
21+
impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl {
22+
async fn run(&self, context: WorkflowContext<'_>, input: String) -> HandlerResult<String> {
23+
context.set(MY_STATE, input);
24+
25+
let promise: String = context.promise(MY_PROMISE).await?;
26+
27+
if context.peek_promise::<String>(MY_PROMISE).await?.is_none() {
28+
return Err(TerminalError::new("Durable promise should be completed").into());
29+
}
30+
31+
Ok(promise)
32+
}
33+
34+
async fn unblock(
35+
&self,
36+
context: SharedWorkflowContext<'_>,
37+
output: String,
38+
) -> HandlerResult<()> {
39+
context.resolve_promise(MY_PROMISE, output);
40+
Ok(())
41+
}
42+
43+
async fn get_state(
44+
&self,
45+
context: SharedWorkflowContext<'_>,
46+
) -> HandlerResult<Json<Option<String>>> {
47+
Ok(Json(context.get::<String>(MY_STATE).await?))
48+
}
49+
}

test-services/src/list_object.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use restate_sdk::prelude::*;
2+
3+
#[restate_sdk::object]
4+
#[name = "ListObject"]
5+
pub(crate) trait ListObject {
6+
#[name = "append"]
7+
async fn append(value: String) -> HandlerResult<()>;
8+
#[name = "get"]
9+
async fn get() -> HandlerResult<Json<Vec<String>>>;
10+
#[name = "clear"]
11+
async fn clear() -> HandlerResult<Json<Vec<String>>>;
12+
}
13+
14+
pub(crate) struct ListObjectImpl;
15+
16+
const LIST: &str = "list";
17+
18+
impl ListObject for ListObjectImpl {
19+
async fn append(&self, ctx: ObjectContext<'_>, value: String) -> HandlerResult<()> {
20+
let mut list = ctx
21+
.get::<Json<Vec<String>>>(LIST)
22+
.await?
23+
.unwrap_or_default()
24+
.into_inner();
25+
list.push(value);
26+
ctx.set(LIST, Json(list));
27+
Ok(())
28+
}
29+
30+
async fn get(&self, ctx: ObjectContext<'_>) -> HandlerResult<Json<Vec<String>>> {
31+
Ok(ctx
32+
.get::<Json<Vec<String>>>(LIST)
33+
.await?
34+
.unwrap_or_default())
35+
}
36+
37+
async fn clear(&self, ctx: ObjectContext<'_>) -> HandlerResult<Json<Vec<String>>> {
38+
let get = ctx
39+
.get::<Json<Vec<String>>>(LIST)
40+
.await?
41+
.unwrap_or_default();
42+
ctx.clear(LIST);
43+
Ok(get)
44+
}
45+
}

test-services/src/main.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
mod awakeable_holder;
2+
mod block_and_wait_workflow;
13
mod counter;
4+
mod list_object;
25
mod map_object;
36
mod proxy;
47

@@ -22,6 +25,19 @@ async fn main() {
2225
if services == "*" || services.contains("MapObject") {
2326
builder = builder.with_service(map_object::MapObject::serve(map_object::MapObjectImpl))
2427
}
28+
if services == "*" || services.contains("ListObject") {
29+
builder = builder.with_service(list_object::ListObject::serve(list_object::ListObjectImpl))
30+
}
31+
if services == "*" || services.contains("AwakeableHolder") {
32+
builder = builder.with_service(awakeable_holder::AwakeableHolder::serve(
33+
awakeable_holder::AwakeableHolderImpl,
34+
))
35+
}
36+
if services == "*" || services.contains("BlockAndWaitWorkflow") {
37+
builder = builder.with_service(block_and_wait_workflow::BlockAndWaitWorkflow::serve(
38+
block_and_wait_workflow::BlockAndWaitWorkflowImpl,
39+
))
40+
}
2541

2642
HyperServer::new(builder.build())
2743
.listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap())

0 commit comments

Comments
 (0)