Skip to content

Commit 8656b2c

Browse files
committed
Code review and other fixes
- Add patient schema to management and schema tests - Use results from recursion - Reorder patient models and add unencrypted bill amount field - Specify databases for management and schema tests
1 parent 9a4fb42 commit 8656b2c

File tree

4 files changed

+128
-54
lines changed

4 files changed

+128
-54
lines changed

django_mongodb_backend/schema.py

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -483,52 +483,111 @@ def _create_collection(self, model):
483483
else:
484484
db.create_collection(db_table)
485485

486-
def _get_encrypted_fields(self, model, client, create_data_keys=False):
486+
def _get_encrypted_fields(
487+
self, model, client, create_data_keys=False, key_alt_name=None, client_encryption=None
488+
):
489+
"""
490+
Recursively collect encryption schema data for fields in a model.
491+
492+
key_alt_name is the base path for this level, typically model._meta.db_table
493+
"""
487494
connection = self.connection
488495
fields = model._meta.fields
489-
options = client._options
490-
auto_encryption_opts = options.auto_encryption_opts
491-
kms_provider = router.kms_provider(model)
492-
master_key = self.connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
493-
client_encryption = ClientEncryption(
494-
auto_encryption_opts._kms_providers,
495-
auto_encryption_opts._key_vault_namespace,
496-
client,
497-
client.codec_options,
498-
)
499-
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
500-
key_vault_collection = client[key_vault_db][key_vault_coll]
501-
db_table = model._meta.db_table
496+
key_alt_name = key_alt_name or model._meta.db_table
497+
498+
# Initialize ClientEncryption once
499+
if client_encryption is None:
500+
options = client._options
501+
auto_encryption_opts = options.auto_encryption_opts
502+
kms_provider = router.kms_provider(model)
503+
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
504+
client_encryption = ClientEncryption(
505+
auto_encryption_opts._kms_providers,
506+
auto_encryption_opts._key_vault_namespace,
507+
client,
508+
client.codec_options,
509+
)
510+
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
511+
key_vault_collection = client[key_vault_db][key_vault_coll]
512+
else:
513+
auto_encryption_opts = client._options.auto_encryption_opts
514+
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
515+
key_vault_collection = client[key_vault_db][key_vault_coll]
516+
kms_provider = router.kms_provider(model)
517+
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
518+
502519
field_list = []
520+
503521
for field in fields:
522+
new_path = f"{key_alt_name}.{field.column}"
523+
524+
# --- EmbeddedModelField case ---
504525
if isinstance(field, EmbeddedModelField):
505-
# Recursively get encrypted fields for the embedded model.
506-
self._get_encrypted_fields(field.embedded_model, client, create_data_keys)
526+
field_dict = {"bsonType": "object", "path": field.column}
527+
528+
if getattr(field, "encrypted", False):
529+
if create_data_keys:
530+
data_key = client_encryption.create_data_key(
531+
kms_provider=kms_provider,
532+
master_key=master_key,
533+
key_alt_names=[new_path],
534+
)
535+
else:
536+
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
537+
if not key_doc:
538+
raise ValueError(
539+
f"No key found in keyvault for keyAltName={new_path}. "
540+
"Run with '--create-data-keys' to create missing keys."
541+
)
542+
data_key = key_doc["_id"]
543+
544+
field_dict["keyId"] = data_key
545+
546+
if getattr(field, "queries", False):
547+
field_dict["queries"] = field.queries
548+
549+
field_list.append(field_dict)
550+
continue
551+
552+
# Not encrypting whole object — add object entry and recurse
553+
field_list.append(field_dict)
554+
embedded_result = self._get_encrypted_fields(
555+
field.embedded_model,
556+
client,
557+
create_data_keys=create_data_keys,
558+
key_alt_name=new_path,
559+
client_encryption=client_encryption,
560+
)
561+
field_list.extend(embedded_result["fields"])
562+
continue
563+
564+
# --- Leaf encrypted field case ---
507565
if getattr(field, "encrypted", False):
508-
key_alt_name = f"{db_table}.{field.column}"
509566
if create_data_keys:
510567
data_key = client_encryption.create_data_key(
511568
kms_provider=kms_provider,
512569
master_key=master_key,
513-
key_alt_names=[key_alt_name],
570+
key_alt_names=[new_path], # distinct per field
514571
)
515572
else:
516-
key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name})
573+
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
517574
if not key_doc:
518575
raise ValueError(
519-
f"No key found in keyvault for keyAltName={key_alt_name}. "
520-
"You may need to run the management command with "
521-
"'--create-data-keys' to create missing keys."
576+
f"No key found in keyvault for keyAltName={new_path}. "
577+
"Run with '--create-data-keys' to create missing keys."
522578
)
523579
data_key = key_doc["_id"]
580+
524581
field_dict = {
525582
"bsonType": field.db_type(connection),
526583
"path": field.column,
527584
"keyId": data_key,
528585
}
529586
if getattr(field, "queries", False):
530587
field_dict["queries"] = field.queries
588+
531589
field_list.append(field_dict)
590+
532591
return {"fields": field_list}
533592

534593

tests/encryption_/models.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,27 @@
2626
from django_mongodb_backend.models import EmbeddedModel
2727

2828

29-
class Billing(EmbeddedModel):
30-
cc_type = models.CharField(max_length=50)
31-
cc_number = models.CharField(max_length=20)
32-
33-
34-
class PatientRecord(EmbeddedModel):
35-
ssn = EncryptedCharField(max_length=11, queries={"queryType": "equality"})
36-
billing = EncryptedEmbeddedModelField(Billing)
37-
38-
3929
class Patient(models.Model):
4030
patient_name = models.CharField(max_length=255)
4131
patient_id = models.BigIntegerField()
42-
patient_record = EmbeddedModelField(PatientRecord)
32+
patient_record = EmbeddedModelField("PatientRecord")
4333

4434
def __str__(self):
4535
return f"{self.patient_name} ({self.patient_id})"
4636

4737

48-
class EncryptedModel(models.Model):
38+
class PatientRecord(EmbeddedModel):
39+
ssn = EncryptedCharField(max_length=11, queries={"queryType": "equality"})
40+
billing = EncryptedEmbeddedModelField("Billing")
41+
bill_amount = models.DecimalField(max_digits=10, decimal_places=2)
42+
43+
44+
class Billing(EmbeddedModel):
45+
cc_type = models.CharField(max_length=50)
46+
cc_number = models.CharField(max_length=20)
47+
48+
49+
class EncryptedModelBase(models.Model):
4950
"""
5051
Abstract base model for all Encrypted models
5152
that require the 'supports_queryable_encryption' DB feature.
@@ -57,78 +58,78 @@ class Meta:
5758

5859

5960
# Equality-queryable fields
60-
class EncryptedBinaryTest(EncryptedModel):
61+
class EncryptedBinaryTest(EncryptedModelBase):
6162
value = EncryptedBinaryField(queries={"queryType": "equality"})
6263

6364

64-
class EncryptedBooleanTest(EncryptedModel):
65+
class EncryptedBooleanTest(EncryptedModelBase):
6566
value = EncryptedBooleanField(queries={"queryType": "equality"})
6667

6768

68-
class EncryptedCharTest(EncryptedModel):
69+
class EncryptedCharTest(EncryptedModelBase):
6970
value = EncryptedCharField(max_length=255, queries={"queryType": "equality"})
7071

7172

72-
class EncryptedEmailTest(EncryptedModel):
73+
class EncryptedEmailTest(EncryptedModelBase):
7374
value = EncryptedEmailField(max_length=255, queries={"queryType": "equality"})
7475

7576

76-
class EncryptedGenericIPAddressTest(EncryptedModel):
77+
class EncryptedGenericIPAddressTest(EncryptedModelBase):
7778
value = EncryptedGenericIPAddressField(queries={"queryType": "equality"})
7879

7980

80-
class EncryptedTextTest(EncryptedModel):
81+
class EncryptedTextTest(EncryptedModelBase):
8182
value = EncryptedTextField(queries={"queryType": "equality"})
8283

8384

84-
class EncryptedURLTest(EncryptedModel):
85+
class EncryptedURLTest(EncryptedModelBase):
8586
value = EncryptedURLField(max_length=500, queries={"queryType": "equality"})
8687

8788

8889
# Range-queryable fields (also support equality)
89-
class EncryptedBigIntegerTest(EncryptedModel):
90+
class EncryptedBigIntegerTest(EncryptedModelBase):
9091
value = EncryptedBigIntegerField(queries={"queryType": "range"})
9192

9293

93-
class EncryptedDateTest(EncryptedModel):
94+
class EncryptedDateTest(EncryptedModelBase):
9495
value = EncryptedDateField(queries={"queryType": "range"})
9596

9697

97-
class EncryptedDateTimeTest(EncryptedModel):
98+
class EncryptedDateTimeTest(EncryptedModelBase):
9899
value = EncryptedDateTimeField(queries={"queryType": "range"})
99100

100101

101-
class EncryptedDecimalTest(EncryptedModel):
102+
class EncryptedDecimalTest(EncryptedModelBase):
102103
value = EncryptedDecimalField(max_digits=10, decimal_places=2, queries={"queryType": "range"})
103104

104105

105-
class EncryptedDurationTest(EncryptedModel):
106+
class EncryptedDurationTest(EncryptedModelBase):
106107
value = EncryptedDurationField(queries={"queryType": "range"})
107108

108109

109-
class EncryptedFloatTest(EncryptedModel):
110+
class EncryptedFloatTest(EncryptedModelBase):
110111
value = EncryptedFloatField(queries={"queryType": "range"})
111112

112113

113-
class EncryptedIntegerTest(EncryptedModel):
114+
class EncryptedIntegerTest(EncryptedModelBase):
114115
value = EncryptedIntegerField(queries={"queryType": "range"})
115116

116117

117-
class EncryptedPositiveBigIntegerTest(EncryptedModel):
118+
class EncryptedPositiveBigIntegerTest(EncryptedModelBase):
118119
value = EncryptedPositiveBigIntegerField(queries={"queryType": "range"})
119120

120121

121-
class EncryptedPositiveIntegerTest(EncryptedModel):
122+
class EncryptedPositiveIntegerTest(EncryptedModelBase):
122123
value = EncryptedPositiveIntegerField(queries={"queryType": "range"})
123124

124125

125-
class EncryptedPositiveSmallIntegerTest(EncryptedModel):
126+
class EncryptedPositiveSmallIntegerTest(EncryptedModelBase):
126127
value = EncryptedPositiveSmallIntegerField(queries={"queryType": "range"})
127128

128129

129-
class EncryptedSmallIntegerTest(EncryptedModel):
130+
class EncryptedSmallIntegerTest(EncryptedModelBase):
130131
value = EncryptedSmallIntegerField(queries={"queryType": "range"})
131132

132133

133-
class EncryptedTimeTest(EncryptedModel):
134+
class EncryptedTimeTest(EncryptedModelBase):
134135
value = EncryptedTimeField(queries={"queryType": "range"})

tests/encryption_/test_management.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@
88
@skipUnlessDBFeature("supports_queryable_encryption")
99
@modify_settings(INSTALLED_APPS={"prepend": "django_mongodb_backend"})
1010
class CommandTests(TestCase):
11+
databases = {"default", "encrypted"}
1112
maxDiff = None
1213

1314
# Expected encrypted field maps for all Encrypted* models
1415
expected_maps = {
16+
"encryption__patientrecord": {
17+
"fields": [
18+
{"bsonType": "string", "path": "ssn", "queries": {"queryType": "equality"}},
19+
{"bsonType": "object", "path": "billing"},
20+
]
21+
},
1522
# Equality-queryable fields
1623
"encryption__encryptedbinarytest": {
1724
"fields": [

tests/encryption_/test_schema.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
from django.db import connections
22
from django.test import TestCase
33

4-
from . import models # your encryption_ models file with Encrypted*Test classes
4+
from . import models
55

66

77
class SchemaTests(TestCase):
8+
databases = {"default", "encrypted"}
89
maxDiff = None
910

1011
# Expected encrypted fields map per model
1112
expected_map = {
13+
"PatientRecord": {
14+
"fields": [
15+
{"bsonType": "string", "path": "ssn", "queries": {"queryType": "equality"}},
16+
{"bsonType": "object", "path": "billing"},
17+
]
18+
},
1219
"EncryptedBinaryTest": {
1320
"fields": [
1421
{"bsonType": "binData", "path": "value", "queries": {"queryType": "equality"}}

0 commit comments

Comments
 (0)