From 784a201e8f90a3a4a0456768c4c6d6b326090d56 Mon Sep 17 00:00:00 2001 From: "Zifan.Xu" Date: Mon, 25 Nov 2019 16:05:12 -0500 Subject: [PATCH 1/5] Add filter class to dask and do the tests for it --- streamz/dask.py | 35 ++++++++++++++++++++++ streamz/tests/test_dask.py | 61 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/streamz/dask.py b/streamz/dask.py index d0c9d4e2..b0cc66aa 100644 --- a/streamz/dask.py +++ b/streamz/dask.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +from functools import wraps +from .core import _truthy from operator import getitem from tornado import gen @@ -11,6 +13,21 @@ from . import core, sources +NULL_COMPUTE = "~~NULL_COMPUTE~~" + + +def return_null(func): + @wraps(func) + def inner(x, *args, **kwargs): + tv = func(x, *args, **kwargs) + if tv: + return x + else: + return NULL_COMPUTE + + return inner + + class DaskStream(Stream): """ A Parallel stream using Dask @@ -140,6 +157,24 @@ def update(self, x, who=None): return self._emit(result) +@DaskStream.register_api() +class filter(DaskStream): + def __init__(self, upstream, predicate, *args, **kwargs): + if predicate is None: + predicate = _truthy + self.predicate = return_null(predicate) + stream_name = kwargs.pop("stream_name", None) + self.kwargs = kwargs + self.args = args + + DaskStream.__init__(self, upstream, stream_name=stream_name) + + def update(self, x, who=None): + client = self.default_client() + result = client.submit(self.predicate, x, *self.args, **self.kwargs) + return self._emit(result) + + @DaskStream.register_api() class buffer(DaskStream, core.buffer): pass diff --git a/streamz/tests/test_dask.py b/streamz/tests/test_dask.py index d4da4fc4..465a7056 100644 --- a/streamz/tests/test_dask.py +++ b/streamz/tests/test_dask.py @@ -131,6 +131,67 @@ def test_buffer(c, s, a, b): assert source.loop == c.loop +@pytest.mark.slow +def test_filter(backend): + source = Stream(asynchronous=True) + futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_buffer(backend): + source = Stream(asynchronous=True) + futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) + futures_L = futures.sink_to_list() + L = futures.buffer(10).gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + while len(L) < 3: + yield gen.sleep(.01) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_map(backend): + source = Stream(asynchronous=True) + futures = ( + scatter(source, backend=backend).filter(lambda x: x % 2 == 0).map(inc) + ) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + + assert L == [1, 3, 5] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.slow +def test_filter_starmap(backend): + source = Stream(asynchronous=True) + futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0) + futures = futures1.starmap(add) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit((i, i)) + + assert L == [0, 4, 8] + assert all(isinstance(f, Future) for f in futures_L) + + @pytest.mark.slow def test_buffer_sync(loop): # noqa: F811 with cluster() as (s, [a, b]): From bd7024ef0fa09a4a41900de149adec3379dea0f0 Mon Sep 17 00:00:00 2001 From: "Zifan.Xu" Date: Tue, 26 Nov 2019 11:02:31 -0500 Subject: [PATCH 2/5] Changed the gather class and removed backend from test_filter --- streamz/clients.py | 96 ++++++++++++++++++++++++++++++++++++++ streamz/dask.py | 41 ++++++++++++++-- streamz/tests/test_dask.py | 4 +- 3 files changed, 136 insertions(+), 5 deletions(-) create mode 100644 streamz/clients.py diff --git a/streamz/clients.py b/streamz/clients.py new file mode 100644 index 00000000..94ab618f --- /dev/null +++ b/streamz/clients.py @@ -0,0 +1,96 @@ +from collections import Sequence, MutableMapping +from concurrent.futures import ThreadPoolExecutor, Future +from functools import wraps + +from distributed import default_client as dask_default_client +from tornado import gen + +from .core import identity + + +FILL_COLOR_LOOKUP = {"dask": "cornflowerblue", "threads": "coral"} + + +def result_maybe(future_maybe): + if isinstance(future_maybe, Future): + return future_maybe.result() + else: + if isinstance(future_maybe, Sequence) and not isinstance( + future_maybe, str + ): + aa = [] + for a in future_maybe: + aa.append(result_maybe(a)) + if isinstance(future_maybe, tuple): + aa = tuple(aa) + return aa + elif isinstance(future_maybe, MutableMapping): + for k, v in future_maybe.items(): + future_maybe[k] = result_maybe(v) + return future_maybe + + +def delayed_execution(func): + @wraps(func) + def inner(*args, **kwargs): + args = tuple([result_maybe(v) for v in args]) + kwargs = {k: result_maybe(v) for k, v in kwargs.items()} + return func(*args, **kwargs) + + return inner + + +def executor_to_client(executor): + executor._submit = executor.submit + + @wraps(executor.submit) + def inner(fn, *args, **kwargs): + wfn = delayed_execution(fn) + return executor._submit(wfn, *args, **kwargs) + + executor.submit = inner + + @gen.coroutine + def scatter(x, asynchronous=True): + f = executor.submit(identity, x) + return f + + executor.scatter = getattr(executor, "scatter", scatter) + + @gen.coroutine + def gather(x, asynchronous=True): + # If we have a sequence of futures await each one + if isinstance(x, Sequence): + final_result = [] + for sub_x in x: + yx = yield sub_x + final_result.append(yx) + result = type(x)(final_result) + else: + result = yield x + return result + + executor.gather = getattr(executor, "gather", gather) + return executor + + +thread_ex_list = [] + + +def thread_default_client(): + if thread_ex_list: + ex = thread_ex_list[0] + if ex._shutdown: + thread_ex_list.pop() + ex = executor_to_client(ThreadPoolExecutor()) + thread_ex_list.append(ex) + else: + ex = executor_to_client(ThreadPoolExecutor()) + thread_ex_list.append(ex) + return ex + + +DEFAULT_BACKENDS = { + "dask": dask_default_client, + "thread": thread_default_client, +} diff --git a/streamz/dask.py b/streamz/dask.py index b0cc66aa..5b2cd19d 100644 --- a/streamz/dask.py +++ b/streamz/dask.py @@ -2,6 +2,8 @@ from functools import wraps from .core import _truthy +from .core import get_io_loop +from .clients import DEFAULT_BACKENDS from operator import getitem from tornado import gen @@ -12,6 +14,8 @@ from .core import Stream from . import core, sources +from collections import Sequence + NULL_COMPUTE = "~~NULL_COMPUTE~~" @@ -134,12 +138,43 @@ class gather(core.Stream): buffer scatter """ + + def __init__(self, *args, backend="dask", **kwargs): + super().__init__(*args, **kwargs) + upstream_backends = set( + [getattr(u, "default_client", None) for u in self.upstreams] + ) + if None in upstream_backends: + upstream_backends.remove(None) + if len(upstream_backends) > 1: + raise RuntimeError("Mixing backends is not supported") + elif upstream_backends: + self.default_client = upstream_backends.pop() + else: + self.default_client = DEFAULT_BACKENDS.get(backend, backend) + if "loop" not in kwargs and getattr( + self.default_client(), "loop", None + ): + loop = self.default_client().loop + self._set_loop(loop) + if kwargs.get("ensure_io_loop", False) and not self.loop: + self._set_asynchronous(False) + if self.loop is None and self.asynchronous is not None: + self._set_loop(get_io_loop(self.asynchronous)) + @gen.coroutine def update(self, x, who=None): - client = default_client() + client = self.default_client() result = yield client.gather(x, asynchronous=True) - result2 = yield self._emit(result) - raise gen.Return(result2) + if ( + not ( + isinstance(result, Sequence) + and any(r == NULL_COMPUTE for r in result) + ) + and result != NULL_COMPUTE + ): + result2 = yield self._emit(result) + raise gen.Return(result2) @DaskStream.register_api() diff --git a/streamz/tests/test_dask.py b/streamz/tests/test_dask.py index 465a7056..7b09c14a 100644 --- a/streamz/tests/test_dask.py +++ b/streamz/tests/test_dask.py @@ -132,9 +132,9 @@ def test_buffer(c, s, a, b): @pytest.mark.slow -def test_filter(backend): +def test_filter(): source = Stream(asynchronous=True) - futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) + futures = scatter(source).filter(lambda x: x % 2 == 0) futures_L = futures.sink_to_list() L = futures.gather().sink_to_list() From 8e88224c3ad3b8f7cac0b3dc0c32e3abe534c09d Mon Sep 17 00:00:00 2001 From: "Zifan.Xu" Date: Tue, 26 Nov 2019 11:44:14 -0500 Subject: [PATCH 3/5] Removed backend for all the filter tests --- streamz/tests/test_dask.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/streamz/tests/test_dask.py b/streamz/tests/test_dask.py index 7b09c14a..9be7ff57 100644 --- a/streamz/tests/test_dask.py +++ b/streamz/tests/test_dask.py @@ -146,9 +146,9 @@ def test_filter(): @pytest.mark.slow -def test_filter_buffer(backend): +def test_filter_buffer(): source = Stream(asynchronous=True) - futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) + futures = scatter(source).filter(lambda x: x % 2 == 0) futures_L = futures.sink_to_list() L = futures.buffer(10).gather().sink_to_list() @@ -162,10 +162,10 @@ def test_filter_buffer(backend): @pytest.mark.slow -def test_filter_map(backend): +def test_filter_map(): source = Stream(asynchronous=True) futures = ( - scatter(source, backend=backend).filter(lambda x: x % 2 == 0).map(inc) + scatter(source).filter(lambda x: x % 2 == 0).map(inc) ) futures_L = futures.sink_to_list() L = futures.gather().sink_to_list() @@ -178,9 +178,9 @@ def test_filter_map(backend): @pytest.mark.slow -def test_filter_starmap(backend): +def test_filter_starmap(): source = Stream(asynchronous=True) - futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0) + futures1 = scatter(source).filter(lambda x: x[1] % 2 == 0) futures = futures1.starmap(add) futures_L = futures.sink_to_list() L = futures.gather().sink_to_list() From 721d4af2ec20e83bd7de4ba9ac5d7fdd89819061 Mon Sep 17 00:00:00 2001 From: "Zifan.Xu" Date: Tue, 26 Nov 2019 16:26:26 -0500 Subject: [PATCH 4/5] Deleted clients, fixed class map and filter --- streamz/clients.py | 96 ---------------------------------------------- streamz/dask.py | 44 ++++++++------------- 2 files changed, 16 insertions(+), 124 deletions(-) delete mode 100644 streamz/clients.py diff --git a/streamz/clients.py b/streamz/clients.py deleted file mode 100644 index 94ab618f..00000000 --- a/streamz/clients.py +++ /dev/null @@ -1,96 +0,0 @@ -from collections import Sequence, MutableMapping -from concurrent.futures import ThreadPoolExecutor, Future -from functools import wraps - -from distributed import default_client as dask_default_client -from tornado import gen - -from .core import identity - - -FILL_COLOR_LOOKUP = {"dask": "cornflowerblue", "threads": "coral"} - - -def result_maybe(future_maybe): - if isinstance(future_maybe, Future): - return future_maybe.result() - else: - if isinstance(future_maybe, Sequence) and not isinstance( - future_maybe, str - ): - aa = [] - for a in future_maybe: - aa.append(result_maybe(a)) - if isinstance(future_maybe, tuple): - aa = tuple(aa) - return aa - elif isinstance(future_maybe, MutableMapping): - for k, v in future_maybe.items(): - future_maybe[k] = result_maybe(v) - return future_maybe - - -def delayed_execution(func): - @wraps(func) - def inner(*args, **kwargs): - args = tuple([result_maybe(v) for v in args]) - kwargs = {k: result_maybe(v) for k, v in kwargs.items()} - return func(*args, **kwargs) - - return inner - - -def executor_to_client(executor): - executor._submit = executor.submit - - @wraps(executor.submit) - def inner(fn, *args, **kwargs): - wfn = delayed_execution(fn) - return executor._submit(wfn, *args, **kwargs) - - executor.submit = inner - - @gen.coroutine - def scatter(x, asynchronous=True): - f = executor.submit(identity, x) - return f - - executor.scatter = getattr(executor, "scatter", scatter) - - @gen.coroutine - def gather(x, asynchronous=True): - # If we have a sequence of futures await each one - if isinstance(x, Sequence): - final_result = [] - for sub_x in x: - yx = yield sub_x - final_result.append(yx) - result = type(x)(final_result) - else: - result = yield x - return result - - executor.gather = getattr(executor, "gather", gather) - return executor - - -thread_ex_list = [] - - -def thread_default_client(): - if thread_ex_list: - ex = thread_ex_list[0] - if ex._shutdown: - thread_ex_list.pop() - ex = executor_to_client(ThreadPoolExecutor()) - thread_ex_list.append(ex) - else: - ex = executor_to_client(ThreadPoolExecutor()) - thread_ex_list.append(ex) - return ex - - -DEFAULT_BACKENDS = { - "dask": dask_default_client, - "thread": thread_default_client, -} diff --git a/streamz/dask.py b/streamz/dask.py index 5b2cd19d..48a35a77 100644 --- a/streamz/dask.py +++ b/streamz/dask.py @@ -2,8 +2,6 @@ from functools import wraps from .core import _truthy -from .core import get_io_loop -from .clients import DEFAULT_BACKENDS from operator import getitem from tornado import gen @@ -32,6 +30,19 @@ def inner(x, *args, **kwargs): return inner +def filter_null_wrapper(func): + @wraps(func) + def inner(*args, **kwargs): + if any(a == NULL_COMPUTE for a in args) or any( + v == NULL_COMPUTE for v in kwargs.values() + ): + return NULL_COMPUTE + else: + return func(*args, **kwargs) + + return inner + + class DaskStream(Stream): """ A Parallel stream using Dask @@ -67,7 +78,7 @@ def __init__(self, *args, **kwargs): @DaskStream.register_api() class map(DaskStream): def __init__(self, upstream, func, *args, **kwargs): - self.func = func + self.func = filter_null_wrapper(func) self.kwargs = kwargs self.args = args @@ -139,32 +150,9 @@ class gather(core.Stream): scatter """ - def __init__(self, *args, backend="dask", **kwargs): - super().__init__(*args, **kwargs) - upstream_backends = set( - [getattr(u, "default_client", None) for u in self.upstreams] - ) - if None in upstream_backends: - upstream_backends.remove(None) - if len(upstream_backends) > 1: - raise RuntimeError("Mixing backends is not supported") - elif upstream_backends: - self.default_client = upstream_backends.pop() - else: - self.default_client = DEFAULT_BACKENDS.get(backend, backend) - if "loop" not in kwargs and getattr( - self.default_client(), "loop", None - ): - loop = self.default_client().loop - self._set_loop(loop) - if kwargs.get("ensure_io_loop", False) and not self.loop: - self._set_asynchronous(False) - if self.loop is None and self.asynchronous is not None: - self._set_loop(get_io_loop(self.asynchronous)) - @gen.coroutine def update(self, x, who=None): - client = self.default_client() + client = default_client() result = yield client.gather(x, asynchronous=True) if ( not ( @@ -205,7 +193,7 @@ def __init__(self, upstream, predicate, *args, **kwargs): DaskStream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): - client = self.default_client() + client = default_client() result = client.submit(self.predicate, x, *self.args, **self.kwargs) return self._emit(result) From b7472e66f1a975cf77046864df91d8460f7138cb Mon Sep 17 00:00:00 2001 From: "Zifan.Xu" Date: Thu, 5 Dec 2019 16:16:37 -0500 Subject: [PATCH 5/5] Change "==" to is and "!=" to is not to use hash --- streamz/dask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/streamz/dask.py b/streamz/dask.py index 48a35a77..ae11a7a2 100644 --- a/streamz/dask.py +++ b/streamz/dask.py @@ -33,8 +33,8 @@ def inner(x, *args, **kwargs): def filter_null_wrapper(func): @wraps(func) def inner(*args, **kwargs): - if any(a == NULL_COMPUTE for a in args) or any( - v == NULL_COMPUTE for v in kwargs.values() + if any(a is NULL_COMPUTE for a in args) or any( + v is NULL_COMPUTE for v in kwargs.values() ): return NULL_COMPUTE else: @@ -157,9 +157,9 @@ def update(self, x, who=None): if ( not ( isinstance(result, Sequence) - and any(r == NULL_COMPUTE for r in result) + and any(r is NULL_COMPUTE for r in result) ) - and result != NULL_COMPUTE + and result is not NULL_COMPUTE ): result2 = yield self._emit(result) raise gen.Return(result2)