1111from graphene .types .argument import to_arguments
1212
1313
14+ from .utils import get_model_reference_fields
15+
16+
1417# noqa
1518class MongoengineListField (Field ):
1619
@@ -60,26 +63,35 @@ def model(self):
6063 @property
6164 def args (self ):
6265 return to_arguments (
63- self ._base_args or OrderedDict (), self .default_filter_args
66+ self ._base_args or OrderedDict (),
67+ dict (self .field_args , ** self .reference_args )
6468 )
6569
6670 @args .setter
6771 def args (self , args ):
6872 self ._base_args = args
6973
7074 @property
71- def default_filter_args (self ):
75+ def field_args (self ):
7276 def is_filterable (kv ):
7377 return hasattr (kv [1 ], '_type' ) \
7478 and callable (getattr (kv [1 ]._type , '_of_type' , None ))
7579
7680 return reduce (
7781 lambda r , kv : r .update (
7882 {kv [0 ]: kv [1 ]._type ._of_type ()}) or r if is_filterable (kv ) else r ,
79- self .fields .items (),
80- {}
83+ self .fields .items (), {}
8184 )
8285
86+ @property
87+ def reference_args (self ):
88+ def get_reference_field (r , kv ):
89+ if callable (getattr (kv [1 ], 'get_type' , None )):
90+ node = kv [1 ].get_type ()._type ._meta
91+ r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
92+ return r
93+ return reduce (get_reference_field , self .fields .items (), {})
94+
8395 @property
8496 def filter_fields (self ):
8597 return self ._type ._meta .filter_fields
@@ -95,8 +107,17 @@ def get_query(cls, model, info, **args):
95107 return []
96108
97109 objs = model .objects ()
98-
99110 if args :
111+ reference_fields = get_model_reference_fields (model )
112+ reference_args = {}
113+ for arg_name , arg in args .copy ().items ():
114+ if arg_name in reference_fields :
115+ reference_model = model ._fields [arg_name ]
116+ pk = from_global_id (args .pop (arg_name ))[- 1 ]
117+ reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
118+ reference_args [arg_name ] = reference_obj
119+
120+ args .update (reference_args )
100121 first = args .pop ('first' , None )
101122 last = args .pop ('last' , None )
102123 id = args .pop ('id' , None )
@@ -121,7 +142,7 @@ def get_query(cls, model, info, **args):
121142 if first is not None :
122143 objs = objs [:first ]
123144 if last is not None :
124- # fix for https://github.com/graphql-python/graphene-mongo/issues/20
145+ # https://github.com/graphql-python/graphene-mongo/issues/20
125146 objs = objs [- (last + 1 ):]
126147
127148 return objs
0 commit comments