@@ -132,7 +132,7 @@ def get_transform(self, name):
132132 if transform :
133133 return transform
134134 field = self .embedded_model ._meta .get_field (name )
135- return KeyTransformFactory ( name , field )
135+ return EmbeddedModelTransformFactory ( field )
136136
137137 def validate (self , value , model_instance ):
138138 super ().validate (value , model_instance )
@@ -156,39 +156,40 @@ def formfield(self, **kwargs):
156156 )
157157
158158
159- class KeyTransform (Transform ):
160- def __init__ (self , key_name , ref_field , * args , ** kwargs ):
159+ class EmbeddedModelTransform (Transform ):
160+ def __init__ (self , field , * args , ** kwargs ):
161161 super ().__init__ (* args , ** kwargs )
162- self .key_name = str (key_name )
163- self .ref_field = ref_field
162+ # self.field aliases self._field via BaseExpression.field returning
163+ # self.output_field.
164+ self ._field = field
164165
165166 def get_lookup (self , name ):
166- return self .ref_field .get_lookup (name )
167+ return self .field .get_lookup (name )
167168
168169 def get_transform (self , name ):
169170 """
170171 Validate that `name` is either a field of an embedded model or a
171172 lookup on an embedded model's field.
172173 """
173- if transform := self .ref_field .get_transform (name ):
174+ if transform := self .field .get_transform (name ):
174175 return transform
175- suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
176+ suggested_lookups = difflib .get_close_matches (name , self .field .get_lookups ())
176177 if suggested_lookups :
177178 suggested_lookups = " or " .join (suggested_lookups )
178179 suggestion = f", perhaps you meant { suggested_lookups } ?"
179180 else :
180181 suggestion = "."
181182 raise FieldDoesNotExist (
182183 f"Unsupported lookup '{ name } ' for "
183- f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
184+ f"{ self .field .__class__ .__name__ } '{ self .field .name } '"
184185 f"{ suggestion } "
185186 )
186187
187188 def as_mql (self , compiler , connection , as_path = False ):
188189 previous = self
189190 columns = []
190- while isinstance (previous , KeyTransform ):
191- columns .insert (0 , previous .ref_field .column )
191+ while isinstance (previous , EmbeddedModelTransform ):
192+ columns .insert (0 , previous .field .column )
192193 previous = previous .lhs
193194 if as_path :
194195 mql = previous .as_mql (compiler , connection , as_path = True )
@@ -201,13 +202,12 @@ def as_mql(self, compiler, connection, as_path=False):
201202
202203 @property
203204 def output_field (self ):
204- return self .ref_field
205+ return self ._field
205206
206207
207- class KeyTransformFactory :
208- def __init__ (self , key_name , ref_field ):
209- self .key_name = key_name
210- self .ref_field = ref_field
208+ class EmbeddedModelTransformFactory :
209+ def __init__ (self , field ):
210+ self .field = field
211211
212212 def __call__ (self , * args , ** kwargs ):
213- return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
213+ return EmbeddedModelTransform (self .field , * args , ** kwargs )
0 commit comments