|
3 | 3 | from django.core import checks |
4 | 4 | from django.core.exceptions import FieldDoesNotExist |
5 | 5 | from django.db import models |
| 6 | +from django.db.models import lookups |
6 | 7 | from django.db.models.fields.related import lazy_related_operation |
7 | 8 | from django.db.models.lookups import Transform |
8 | 9 |
|
9 | 10 | from .. import forms |
| 11 | +from ..query_utils import process_lhs, process_rhs |
10 | 12 | from .json import build_json_mql_path |
11 | 13 |
|
12 | 14 |
|
@@ -149,6 +151,30 @@ def formfield(self, **kwargs): |
149 | 151 | ) |
150 | 152 |
|
151 | 153 |
|
| 154 | +@EmbeddedModelField.register_lookup |
| 155 | +class EMFExact(lookups.Exact): |
| 156 | + def model_to_dict(self, instance): |
| 157 | + """Return a dict containing the data in a model instance.""" |
| 158 | + data = {} |
| 159 | + for f in instance._meta.concrete_fields: |
| 160 | + value = f.value_from_object(instance) |
| 161 | + # Unless explicitly set, primary keys aren't included in embedded |
| 162 | + # models. |
| 163 | + if f.primary_key and value is None: |
| 164 | + continue |
| 165 | + data[f"{f.name}"] = value |
| 166 | + return data |
| 167 | + |
| 168 | + def as_mql(self, compiler, connection): |
| 169 | + lhs_mql = process_lhs(self, compiler, connection) |
| 170 | + value = process_rhs(self, compiler, connection) |
| 171 | + if isinstance(value, models.Model): |
| 172 | + value = self.model_to_dict(value) |
| 173 | + prefix = self.lhs.as_mql(compiler, connection) |
| 174 | + return {"$and": [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]} |
| 175 | + return connection.mongo_operators[self.lookup_name](lhs_mql, value) |
| 176 | + |
| 177 | + |
152 | 178 | class KeyTransform(Transform): |
153 | 179 | def __init__(self, key_name, ref_field, *args, **kwargs): |
154 | 180 | super().__init__(*args, **kwargs) |
@@ -193,7 +219,13 @@ def preprocess_lhs(self, compiler, connection): |
193 | 219 | previous = previous.lhs |
194 | 220 | mql = previous.as_mql(compiler, connection) |
195 | 221 | # The first json_key_transform is the field name. |
196 | | - embedded_key_transforms.append(json_key_transforms.pop(0)) |
| 222 | + try: |
| 223 | + field_name = json_key_transforms.pop(0) |
| 224 | + except IndexError: |
| 225 | + # This is a lookup of the embedded model itself. |
| 226 | + pass |
| 227 | + else: |
| 228 | + embedded_key_transforms.append(field_name) |
197 | 229 | return mql, embedded_key_transforms, json_key_transforms |
198 | 230 |
|
199 | 231 | def as_mql(self, compiler, connection): |
|
0 commit comments