|
4 | 4 | from functools import partial, reduce |
5 | 5 |
|
6 | 6 | import mongoengine |
| 7 | +from graphene import PageInfo |
7 | 8 | from graphene.relay import ConnectionField |
8 | | -from graphene.relay.connection import PageInfo |
9 | 9 | from graphene.types.argument import to_arguments |
10 | 10 | from graphene.types.dynamic import Dynamic |
11 | 11 | from graphene.types.structures import Structure, List |
12 | | -from graphql_relay import from_global_id |
13 | 12 | from graphql_relay.connection.arrayconnection import connection_from_list_slice |
14 | 13 |
|
15 | 14 | from .advanced_types import PointFieldType, MultiPolygonFieldType |
16 | 15 | from .converter import convert_mongoengine_field, MongoEngineConversionError |
17 | 16 | from .registry import get_global_registry |
18 | | -from .utils import get_model_reference_fields, node_from_global_id |
| 17 | +from .utils import get_model_reference_fields, global_id_via_node |
19 | 18 |
|
20 | 19 |
|
21 | 20 | class MongoengineConnectionField(ConnectionField): |
22 | 21 |
|
23 | 22 | def __init__(self, type, *args, **kwargs): |
| 23 | + get_queryset = kwargs.pop('get_queryset', None) |
| 24 | + if get_queryset: |
| 25 | + assert callable(get_queryset), "Attribute `get_queryset` on {} must be callable.".format(self) |
| 26 | + self._get_queryset = get_queryset |
24 | 27 | super(MongoengineConnectionField, self).__init__( |
25 | 28 | type, |
26 | 29 | *args, |
@@ -109,91 +112,65 @@ def get_reference_field(r, kv): |
109 | 112 | def fields(self): |
110 | 113 | return self._type._meta.fields |
111 | 114 |
|
112 | | - @classmethod |
113 | | - def get_query(cls, model, connection, info, **args): |
114 | | - |
115 | | - if not callable(getattr(model, 'objects', None)): |
| 115 | + def get_queryset(self, model, info, **args): |
| 116 | + if self._get_queryset: |
| 117 | + queryset_or_filters = self._get_queryset(model, info, **args) |
| 118 | + if isinstance(queryset_or_filters, mongoengine.QuerySet): |
| 119 | + return queryset_or_filters |
| 120 | + else: |
| 121 | + return model.objects(**queryset_or_filters) |
| 122 | + return model.objects() |
| 123 | + |
| 124 | + def default_resolver(self, _root, info, **args): |
| 125 | + if not callable(getattr(self.model, 'objects', None)): |
116 | 126 | return [], 0 |
117 | 127 |
|
118 | | - objs = model.objects() |
| 128 | + args = args or {} |
| 129 | + |
| 130 | + connection_args = { |
| 131 | + 'first': args.pop('first', None), |
| 132 | + 'last': args.pop('last', None), |
| 133 | + 'before': args.pop('before', None), |
| 134 | + 'after': args.pop('after', None) |
| 135 | + } |
| 136 | + |
| 137 | + objs = self.get_queryset(self.model, info, **args) |
| 138 | + |
119 | 139 | if args: |
120 | | - reference_fields = get_model_reference_fields(model) |
| 140 | + reference_fields = get_model_reference_fields(self.model) |
121 | 141 | reference_args = {} |
122 | 142 | for arg_name, arg in args.copy().items(): |
123 | 143 | if arg_name in reference_fields: |
124 | | - reference_model = model._fields[arg_name] |
125 | | - pk = node_from_global_id(connection, args.pop(arg_name))[-1] |
| 144 | + reference_model = self.model._fields[arg_name] |
| 145 | + pk = global_id_via_node(self.node_type, args.pop(arg_name))[-1] |
126 | 146 | reference_obj = reference_model.document_type_obj.objects(pk=pk).get() |
127 | 147 | reference_args[arg_name] = reference_obj |
128 | 148 |
|
129 | 149 | args.update(reference_args) |
130 | | - first = args.pop('first', None) |
131 | | - last = args.pop('last', None) |
132 | 150 | _id = args.pop('id', None) |
133 | | - before = args.pop('before', None) |
134 | | - after = args.pop('after', None) |
135 | | - |
136 | 151 | if _id is not None: |
137 | | - # https://github.com/graphql-python/graphene/issues/124 |
138 | | - args['pk'] = node_from_global_id(connection, _id)[-1] |
| 152 | + args['pk'] = global_id_via_node(self.node_type, _id)[-1] |
139 | 153 |
|
140 | 154 | objs = objs.filter(**args) |
141 | 155 |
|
142 | | - # https://github.com/graphql-python/graphene-mongo/issues/21 |
143 | | - if after is not None: |
144 | | - _after = int(from_global_id(after)[-1]) |
145 | | - objs = objs[_after:] |
146 | | - |
147 | | - if before is not None: |
148 | | - _before = int(from_global_id(before)[-1]) |
149 | | - objs = objs[:_before] |
150 | | - |
151 | | - list_length = objs.count() |
152 | | - |
153 | | - if first is not None: |
154 | | - objs = objs[:first] |
155 | | - if last is not None: |
156 | | - # https://github.com/graphql-python/graphene-mongo/issues/20 |
157 | | - objs = objs[max(0, list_length - last):] |
158 | | - else: |
159 | | - list_length = objs.count() |
160 | | - |
161 | | - return objs, list_length |
162 | | - |
163 | | - # noqa |
164 | | - @classmethod |
165 | | - def merge_querysets(cls, default_queryset, queryset): |
166 | | - return queryset & default_queryset |
167 | | - |
168 | | - """ |
169 | | - Notes: Not sure how does this work :( |
170 | | - """ |
171 | | - @classmethod |
172 | | - def connection_resolver(cls, resolver, connection, model, root, info, **args): |
173 | | - iterable = resolver(root, info, **args) |
174 | | - |
175 | | - if iterable or iterable == []: |
176 | | - _len = len(iterable) |
177 | | - else: |
178 | | - iterable, _len = cls.get_query(model, connection, info, **args) |
179 | | - |
180 | | - if root: |
181 | | - # If we have a root, we must be at least 1 layer in, right? |
182 | | - _len = 0 |
183 | | - |
184 | 156 | connection = connection_from_list_slice( |
185 | | - iterable, |
186 | | - args, |
187 | | - slice_start=0, |
188 | | - list_length=_len, |
189 | | - list_slice_length=_len, |
190 | | - connection_type=connection, |
| 157 | + list_slice=objs, |
| 158 | + args=connection_args, |
| 159 | + list_length=objs.count(), |
| 160 | + connection_type=self.type, |
| 161 | + edge_type=self.type.Edge, |
191 | 162 | pageinfo_type=PageInfo, |
192 | | - edge_type=connection.Edge, |
193 | 163 | ) |
194 | | - connection.iterable = iterable |
195 | | - connection.length = _len |
| 164 | + connection.iterable = objs |
196 | 165 | return connection |
197 | 166 |
|
| 167 | + def chained_resolver(self, resolver, root, info, **args): |
| 168 | + resolved = resolver(root, info, **args) |
| 169 | + if resolved is not None: |
| 170 | + return resolved |
| 171 | + return self.default_resolver(root, info, **args) |
| 172 | + |
198 | 173 | def get_resolver(self, parent_resolver): |
199 | | - return partial(self.connection_resolver, parent_resolver, self.type, self.model) |
| 174 | + super_resolver = self.resolver or parent_resolver |
| 175 | + resolver = partial(self.chained_resolver, super_resolver) |
| 176 | + return partial(self.connection_resolver, resolver, self.type) |
0 commit comments