11from __future__ import absolute_import
22
3- import mongoengine
43from collections import OrderedDict
54from functools import partial , reduce
65
6+ import mongoengine
77from graphene .relay import ConnectionField
88from graphene .relay .connection import PageInfo
9- from graphql_relay .connection .arrayconnection import connection_from_list_slice
10- from graphql_relay .node .node import from_global_id
119from graphene .types .argument import to_arguments
1210from graphene .types .dynamic import Dynamic
13- from graphene .types .structures import Structure
11+ from graphene .types .structures import Structure , List
12+ from graphql_relay import from_global_id
13+ from graphql_relay .connection .arrayconnection import connection_from_list_slice
1414
1515from .advanced_types import PointFieldType , MultiPolygonFieldType
16- from .utils import get_model_reference_fields
16+ from .converter import convert_mongoengine_field , MongoEngineConversionError
17+ from .registry import get_global_registry
18+ from .utils import get_model_reference_fields , node_from_global_id
1719
1820
1921class MongoengineConnectionField (ConnectionField ):
@@ -43,6 +45,10 @@ def node_type(self):
4345 def model (self ):
4446 return self .node_type ._meta .model
4547
48+ @property
49+ def registry (self ):
50+ return getattr (self .node_type ._meta , 'registry' , get_global_registry ())
51+
4652 @property
4753 def args (self ):
4854 return to_arguments (
@@ -55,12 +61,19 @@ def args(self, args):
5561 self ._base_args = args
5662
5763 def _field_args (self , items ):
58- def is_filterable (v ):
59- if isinstance (v , (ConnectionField , Dynamic )):
64+ def is_filterable (k ):
65+ if not hasattr (self .model , k ):
66+ return False
67+ if isinstance (getattr (self .model , k ), property ):
6068 return False
61- # FIXME: Skip PointTypeField at this moment.
62- if not isinstance (v .type , Structure ) \
63- and isinstance (v .type (), (PointFieldType , MultiPolygonFieldType )):
69+ try :
70+ converted = convert_mongoengine_field (getattr (self .model , k ), self .registry )
71+ except MongoEngineConversionError :
72+ return False
73+ if isinstance (converted , (ConnectionField , Dynamic , List )):
74+ return False
75+ if callable (getattr (converted , 'type' , None )) and isinstance (converted .type (),
76+ (PointFieldType , MultiPolygonFieldType )):
6477 return False
6578 return True
6679
@@ -69,7 +82,7 @@ def get_type(v):
6982 return v .type .of_type ()
7083 return v .type ()
7184
72- return {k : get_type (v ) for k , v in items if is_filterable (v )}
85+ return {k : get_type (v ) for k , v in items if is_filterable (k )}
7386
7487 @property
7588 def field_args (self ):
@@ -78,19 +91,26 @@ def field_args(self):
7891 @property
7992 def reference_args (self ):
8093 def get_reference_field (r , kv ):
81- if callable (getattr (kv [1 ], 'get_type' , None )):
82- node = kv [1 ].get_type ()._type ._meta
83- if not issubclass (node .model , mongoengine .EmbeddedDocument ):
84- r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
94+ field = kv [1 ]
95+ mongo_field = getattr (self .model , kv [0 ], None )
96+ if isinstance (mongo_field , (mongoengine .LazyReferenceField , mongoengine .ReferenceField )):
97+ field = convert_mongoengine_field (mongo_field , self .registry )
98+ if callable (getattr (field , 'get_type' , None )):
99+ _type = field .get_type ()
100+ if _type :
101+ node = _type ._type ._meta
102+ if 'id' in node .fields and not issubclass (node .model , mongoengine .EmbeddedDocument ):
103+ r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
85104 return r
105+
86106 return reduce (get_reference_field , self .fields .items (), {})
87107
88108 @property
89109 def fields (self ):
90110 return self ._type ._meta .fields
91111
92112 @classmethod
93- def get_query (cls , model , info , ** args ):
113+ def get_query (cls , model , connection , info , ** args ):
94114
95115 if not callable (getattr (model , 'objects' , None )):
96116 return [], 0
@@ -102,20 +122,20 @@ def get_query(cls, model, info, **args):
102122 for arg_name , arg in args .copy ().items ():
103123 if arg_name in reference_fields :
104124 reference_model = model ._fields [arg_name ]
105- pk = from_global_id ( args .pop (arg_name ))[- 1 ]
125+ pk = node_from_global_id ( connection , args .pop (arg_name ))[- 1 ]
106126 reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
107127 reference_args [arg_name ] = reference_obj
108128
109129 args .update (reference_args )
110130 first = args .pop ('first' , None )
111131 last = args .pop ('last' , None )
112- id = args .pop ('id' , None )
132+ _id = args .pop ('id' , None )
113133 before = args .pop ('before' , None )
114134 after = args .pop ('after' , None )
115135
116- if id is not None :
136+ if _id is not None :
117137 # https://github.com/graphql-python/graphene/issues/124
118- args ['pk' ] = from_global_id ( id )[- 1 ]
138+ args ['pk' ] = node_from_global_id ( connection , _id )[- 1 ]
119139
120140 objs = objs .filter (** args )
121141
@@ -152,14 +172,14 @@ def merge_querysets(cls, default_queryset, queryset):
152172 def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
153173 iterable = resolver (root , info , ** args )
154174
155- if not iterable :
156- iterable , _len = cls .get_query (model , info , ** args )
175+ if iterable or iterable == []:
176+ _len = len (iterable )
177+ else :
178+ iterable , _len = cls .get_query (model , connection , info , ** args )
157179
158180 if root :
159181 # If we have a root, we must be at least 1 layer in, right?
160182 _len = 0
161- else :
162- _len = len (iterable )
163183
164184 connection = connection_from_list_slice (
165185 iterable ,
0 commit comments