@@ -483,52 +483,111 @@ def _create_collection(self, model):
483483 else :
484484 db .create_collection (db_table )
485485
486- def _get_encrypted_fields (self , model , client , create_data_keys = False ):
486+ def _get_encrypted_fields (
487+ self , model , client , create_data_keys = False , key_alt_name = None , client_encryption = None
488+ ):
489+ """
490+ Recursively collect encryption schema data for fields in a model.
491+
492+ key_alt_name is the base path for this level, typically model._meta.db_table
493+ """
487494 connection = self .connection
488495 fields = model ._meta .fields
489- options = client ._options
490- auto_encryption_opts = options .auto_encryption_opts
491- kms_provider = router .kms_provider (model )
492- master_key = self .connection .settings_dict .get ("KMS_CREDENTIALS" , {}).get (kms_provider )
493- client_encryption = ClientEncryption (
494- auto_encryption_opts ._kms_providers ,
495- auto_encryption_opts ._key_vault_namespace ,
496- client ,
497- client .codec_options ,
498- )
499- key_vault_db , key_vault_coll = auto_encryption_opts ._key_vault_namespace .split ("." , 1 )
500- key_vault_collection = client [key_vault_db ][key_vault_coll ]
501- db_table = model ._meta .db_table
496+ key_alt_name = key_alt_name or model ._meta .db_table
497+
498+ # Initialize ClientEncryption once
499+ if client_encryption is None :
500+ options = client ._options
501+ auto_encryption_opts = options .auto_encryption_opts
502+ kms_provider = router .kms_provider (model )
503+ master_key = connection .settings_dict .get ("KMS_CREDENTIALS" , {}).get (kms_provider )
504+ client_encryption = ClientEncryption (
505+ auto_encryption_opts ._kms_providers ,
506+ auto_encryption_opts ._key_vault_namespace ,
507+ client ,
508+ client .codec_options ,
509+ )
510+ key_vault_db , key_vault_coll = auto_encryption_opts ._key_vault_namespace .split ("." , 1 )
511+ key_vault_collection = client [key_vault_db ][key_vault_coll ]
512+ else :
513+ auto_encryption_opts = client ._options .auto_encryption_opts
514+ key_vault_db , key_vault_coll = auto_encryption_opts ._key_vault_namespace .split ("." , 1 )
515+ key_vault_collection = client [key_vault_db ][key_vault_coll ]
516+ kms_provider = router .kms_provider (model )
517+ master_key = connection .settings_dict .get ("KMS_CREDENTIALS" , {}).get (kms_provider )
518+
502519 field_list = []
520+
503521 for field in fields :
522+ new_path = f"{ key_alt_name } .{ field .column } "
523+
524+ # --- EmbeddedModelField case ---
504525 if isinstance (field , EmbeddedModelField ):
505- # Recursively get encrypted fields for the embedded model.
506- self ._get_encrypted_fields (field .embedded_model , client , create_data_keys )
526+ field_dict = {"bsonType" : "object" , "path" : field .column }
527+
528+ if getattr (field , "encrypted" , False ):
529+ if create_data_keys :
530+ data_key = client_encryption .create_data_key (
531+ kms_provider = kms_provider ,
532+ master_key = master_key ,
533+ key_alt_names = [new_path ],
534+ )
535+ else :
536+ key_doc = key_vault_collection .find_one ({"keyAltNames" : new_path })
537+ if not key_doc :
538+ raise ValueError (
539+ f"No key found in keyvault for keyAltName={ new_path } . "
540+ "Run with '--create-data-keys' to create missing keys."
541+ )
542+ data_key = key_doc ["_id" ]
543+
544+ field_dict ["keyId" ] = data_key
545+
546+ if getattr (field , "queries" , False ):
547+ field_dict ["queries" ] = field .queries
548+
549+ field_list .append (field_dict )
550+ continue
551+
552+ # Not encrypting whole object — add object entry and recurse
553+ field_list .append (field_dict )
554+ embedded_result = self ._get_encrypted_fields (
555+ field .embedded_model ,
556+ client ,
557+ create_data_keys = create_data_keys ,
558+ key_alt_name = new_path ,
559+ client_encryption = client_encryption ,
560+ )
561+ field_list .extend (embedded_result ["fields" ])
562+ continue
563+
564+ # --- Leaf encrypted field case ---
507565 if getattr (field , "encrypted" , False ):
508- key_alt_name = f"{ db_table } .{ field .column } "
509566 if create_data_keys :
510567 data_key = client_encryption .create_data_key (
511568 kms_provider = kms_provider ,
512569 master_key = master_key ,
513- key_alt_names = [key_alt_name ],
570+ key_alt_names = [new_path ], # distinct per field
514571 )
515572 else :
516- key_doc = key_vault_collection .find_one ({"keyAltNames" : key_alt_name })
573+ key_doc = key_vault_collection .find_one ({"keyAltNames" : new_path })
517574 if not key_doc :
518575 raise ValueError (
519- f"No key found in keyvault for keyAltName={ key_alt_name } . "
520- "You may need to run the management command with "
521- "'--create-data-keys' to create missing keys."
576+ f"No key found in keyvault for keyAltName={ new_path } . "
577+ "Run with '--create-data-keys' to create missing keys."
522578 )
523579 data_key = key_doc ["_id" ]
580+
524581 field_dict = {
525582 "bsonType" : field .db_type (connection ),
526583 "path" : field .column ,
527584 "keyId" : data_key ,
528585 }
529586 if getattr (field , "queries" , False ):
530587 field_dict ["queries" ] = field .queries
588+
531589 field_list .append (field_dict )
590+
532591 return {"fields" : field_list }
533592
534593
0 commit comments