Skip to content

Commit cbd3b1b

Browse files
committed
cosmetic edits to _get_encrypted_fields()
1 parent c095f6a commit cbd3b1b

File tree

1 file changed

+24
-31
lines changed

1 file changed

+24
-31
lines changed

django_mongodb_backend/schema.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -477,24 +477,22 @@ def _create_collection(self, model):
477477
# Unencrypted path
478478
db.create_collection(db_table)
479479

480-
def _get_encrypted_fields(self, model, key_alt_name=None, path_prefix=None):
480+
def _get_encrypted_fields(self, model, key_alt_name_prefix=None, path_prefix=None):
481481
"""
482-
Recursively collect encryption schema data for only encrypted fields in a model.
483-
Returns None if no encrypted fields are found anywhere in the model hierarchy.
482+
Return the encrypted fields map for the given model. The "prefix"
483+
arguments are used when this method is called recursively on embedded
484+
models.
484485
"""
485486
connection = self.connection
486487
client = connection.connection
487-
fields = model._meta.fields
488-
key_alt_name = key_alt_name or model._meta.db_table
488+
key_alt_name_prefix = key_alt_name_prefix or model._meta.db_table
489489
path_prefix = path_prefix or ""
490-
491-
options = client._options
492-
auto_encryption_opts = options.auto_encryption_opts
493-
494-
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
495-
key_vault_collection = client[key_vault_db][key_vault_coll]
496-
497-
# Create partial unique index on keyAltNames
490+
auto_encryption_opts = client._options.auto_encryption_opts
491+
key_vault_db, key_vault_collection = auto_encryption_opts._key_vault_namespace.split(".", 1)
492+
key_vault_collection = client[key_vault_db][key_vault_collection]
493+
# Create partial unique index on keyAltNames.
494+
# TODO: find a better place for this. It only needs to run once for an
495+
# application's lifetime.
498496
key_vault_collection.create_index(
499497
"keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}
500498
)
@@ -506,48 +504,43 @@ def _get_encrypted_fields(self, model, key_alt_name=None, path_prefix=None):
506504
else:
507505
# Otherwise, call the user-defined router.kms_provider().
508506
kms_provider = router.kms_provider(model)
509-
# Providing master_key raises an error for the local provider.
510507
master_key = connection.settings_dict.get("KMS_CREDENTIALS").get(kms_provider)
511-
client_encryption = self.connection.client_encryption
512-
508+
# Generate the encrypted fields map.
513509
field_list = []
514-
515-
for field in fields:
516-
new_key_alt_name = f"{key_alt_name}.{field.column}"
510+
for field in model._meta.fields:
511+
key_alt_name = f"{key_alt_name_prefix}.{field.column}"
517512
path = f"{path_prefix}.{field.column}" if path_prefix else field.column
518-
513+
# Check non-encrypted EmbeddedModelFields for encrypted fields.
519514
if isinstance(field, EmbeddedModelField) and not getattr(field, "encrypted", False):
520515
embedded_result = self._get_encrypted_fields(
521516
field.embedded_model,
522-
key_alt_name=new_key_alt_name,
517+
key_alt_name_prefix=key_alt_name,
523518
path_prefix=path,
524519
)
520+
# An EmbeddedModelField may not have any encrypted fields.
525521
if embedded_result:
526522
field_list.extend(embedded_result["fields"])
527523
continue
528-
524+
# Populate data for encrypted field.
529525
if getattr(field, "encrypted", False):
530-
bson_type = field.db_type(connection)
531-
data_key = key_vault_collection.find_one({"keyAltNames": new_key_alt_name})
526+
data_key = key_vault_collection.find_one({"keyAltNames": key_alt_name})
532527
if data_key:
533528
data_key = data_key["_id"]
534529
else:
535-
data_key = client_encryption.create_data_key(
530+
data_key = connection.client_encryption.create_data_key(
536531
kms_provider=kms_provider,
537-
key_alt_names=[new_key_alt_name],
532+
key_alt_names=[key_alt_name],
538533
master_key=master_key,
539534
)
540535
field_dict = {
541-
"bsonType": bson_type,
536+
"bsonType": field.db_type(connection),
542537
"path": path,
543538
"keyId": data_key,
544539
}
545-
queries = getattr(field, "queries", None)
546-
if queries:
540+
if queries := getattr(field, "queries", None):
547541
field_dict["queries"] = queries
548542
field_list.append(field_dict)
549-
550-
return {"fields": field_list} if field_list else None
543+
return {"fields": field_list}
551544

552545

553546
# GISSchemaEditor extends some SchemaEditor methods.

0 commit comments

Comments
 (0)