@@ -8,6 +8,7 @@ import mock
88import grpc
99from grpc.experimental import aio
1010{% if "rest" in opts .transport %}
11+ from collections.abc import Iterable
1112import json
1213{% endif %}
1314import math
@@ -861,8 +862,8 @@ def test_{{ method_name }}_raw_page_lro():
861862{% endfor %} {# method in methods for grpc #}
862863
863864{% for method in service .methods .values () if 'rest' in opts .transport %}{% with method_name = method .name |snake_case + "_unary" if method .operation_service else method .name |snake_case %}{% if method .http_options %}
864- {# TODO(kbandes): remove this if condition when streaming are supported. #}
865- {% if not ( method .server_streaming or method . client_streaming ) %}
865+ {# TODO(kbandes): remove this if condition when client streaming are supported. #}
866+ {% if not method .client_streaming %}
866867@pytest.mark.parametrize("request_type", [
867868 {{ method.input.ident }},
868869 dict,
@@ -884,8 +885,6 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
884885 return_value = None
885886 {% elif method .lro %}
886887 return_value = operations_pb2.Operation(name='operations/spam')
887- {% elif method .server_streaming %}
888- return_value = iter([{{ method.output.ident }}()])
889888 {% else %}
890889 return_value = {{ method.output.ident }}(
891890 {% for field in method .output .fields .values () | rejectattr ('message' )%}
@@ -905,6 +904,8 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
905904 req.return_value.request = PreparedRequest()
906905 {% if method .void %}
907906 json_return_value = ''
907+ {% elif method .server_streaming %}
908+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
908909 {% else %}
909910 json_return_value = {{ method.output.ident }}.to_json(return_value)
910911 {% endif %}
@@ -914,6 +915,10 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
914915 # the request over the wire, so an empty request is fine.
915916 {% if method .client_streaming %}
916917 client.{{ method_name }}(iter([requests]))
918+ {% elif method .server_streaming %}
919+ with mock.patch.object(response_value, 'iter_content') as iter_content:
920+ iter_content.return_value = iter(json_return_value)
921+ response = client.{{ method_name }}(request)
917922 {% else %}
918923 client.{{ method_name }}(request)
919924 {% endif %}
@@ -950,8 +955,6 @@ def test_{{ method.name|snake_case }}_rest(request_type):
950955 return_value = None
951956 {% elif method .lro %}
952957 return_value = operations_pb2.Operation(name='operations/spam')
953- {% elif method .server_streaming %}
954- return_value = iter([{{ method.output.ident }}()])
955958 {% else %}
956959 return_value = {{ method.output.ident }}(
957960 {% for field in method .output .fields .values () | rejectattr ('message' )%}
@@ -974,13 +977,19 @@ def test_{{ method.name|snake_case }}_rest(request_type):
974977 json_return_value = ''
975978 {% elif method .lro %}
976979 json_return_value = json_format.MessageToJson(return_value)
980+ {% elif method .server_streaming %}
981+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
977982 {% else %}
978983 json_return_value = {{ method.output.ident }}.to_json(return_value)
979984 {% endif %}
980985 response_value._content = json_return_value.encode('UTF-8')
981986 req.return_value = response_value
982987 {% if method .client_streaming %}
983988 response = client.{{ method.name|snake_case }}(iter(requests))
989+ {% elif method .server_streaming %}
990+ with mock.patch.object(response_value, 'iter_content') as iter_content:
991+ iter_content.return_value = iter(json_return_value)
992+ response = client.{{ method_name }}(request)
984993 {% else %}
985994 response = client.{{ method_name }}(request)
986995 {% endif %}
@@ -991,6 +1000,11 @@ def test_{{ method.name|snake_case }}_rest(request_type):
9911000
9921001 {% endif %}
9931002
1003+ {% if method .server_streaming %}
1004+ assert isinstance(response, Iterable)
1005+ response = next(response)
1006+ {% endif %}
1007+
9941008 # Establish that the response is the type that we expect.
9951009 {% if method .void %}
9961010 assert response is None
@@ -1085,8 +1099,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
10851099 return_value = None
10861100 {% elif method .lro %}
10871101 return_value = operations_pb2.Operation(name='operations/spam')
1088- {% elif method .server_streaming %}
1089- return_value = iter([{{ method.output.ident }}()])
10901102 {% else %}
10911103 return_value = {{ method.output.ident }}()
10921104 {% endif %}
@@ -1114,6 +1126,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
11141126 json_return_value = ''
11151127 {% elif method .lro %}
11161128 json_return_value = json_format.MessageToJson(return_value)
1129+ {% elif method .server_streaming %}
1130+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
11171131 {% else %}
11181132 json_return_value = {{ method.output.ident }}.to_json(return_value)
11191133 {% endif %}
@@ -1122,6 +1136,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
11221136
11231137 {% if method .client_streaming %}
11241138 response = client.{{ method.name|snake_case }}(iter(requests))
1139+ {% elif method .server_streaming %}
1140+ with mock.patch.object(response_value, 'iter_content') as iter_content:
1141+ iter_content.return_value = iter(json_return_value)
1142+ response = client.{{ method_name }}(request)
11251143 {% else %}
11261144 response = client.{{ method_name }}(request)
11271145 {% endif %}
@@ -1248,8 +1266,6 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12481266 return_value = None
12491267 {% elif method .lro %}
12501268 return_value = operations_pb2.Operation(name='operations/spam')
1251- {% elif method .server_streaming %}
1252- return_value = iter([{{ method.output.ident }}()])
12531269 {% else %}
12541270 return_value = {{ method.output.ident }}()
12551271 {% endif %}
@@ -1261,6 +1277,8 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12611277 json_return_value = ''
12621278 {% elif method .lro %}
12631279 json_return_value = json_format.MessageToJson(return_value)
1280+ {% elif method .server_streaming %}
1281+ json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
12641282 {% else %}
12651283 json_return_value = {{ method.output.ident }}.to_json(return_value)
12661284 {% endif %}
@@ -1281,7 +1299,14 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12811299 {% endfor %}
12821300 )
12831301 mock_args.update(sample_request)
1302+
1303+ {% if method .server_streaming %}
1304+ with mock.patch.object(response_value, 'iter_content') as iter_content:
1305+ iter_content.return_value = iter(json_return_value)
1306+ client.{{ method_name }}(**mock_args)
1307+ {% else %}
12841308 client.{{ method_name }}(**mock_args)
1309+ {% endif %}
12851310
12861311 # Establish that the underlying call was made with the expected
12871312 # request object values.
@@ -1385,6 +1410,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
13851410 response = tuple({{ method.output.ident }}.to_json(x) for x in response)
13861411 return_values = tuple(Response() for i in response)
13871412 for return_val, response_val in zip(return_values, response):
1413+ {% if method .server_streaming %}
1414+ response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
1415+ {% endif %}
13881416 return_val._content = response_val.encode('UTF-8')
13891417 return_val.status_code = 200
13901418 req.side_effect = return_values
0 commit comments