@@ -545,34 +545,42 @@ def _get_encrypted_fields(
545545 master_key = connection .settings_dict .get ("KMS_CREDENTIALS" , {}).get (kms_provider )
546546 client_encryption = getattr (self .connection , "client_encryption" , None )
547547
548+ def _field_dict (bson_type , path , new_key_alt_name , queries = None ):
549+ """Helper to generate a dictionary for an encrypted field.
550+ Included in parent function's scope to avoid passing parameters.
551+ """
552+ data_key = self ._get_data_key (
553+ client_encryption ,
554+ key_vault_collection ,
555+ create_data_keys ,
556+ kms_provider ,
557+ master_key ,
558+ new_key_alt_name ,
559+ )
560+ field_dict = {
561+ "bsonType" : bson_type ,
562+ "path" : path ,
563+ "keyId" : data_key ,
564+ }
565+ if queries :
566+ field_dict ["queries" ] = queries
567+ return field_dict
568+
548569 field_list = []
549570
550571 for field in fields :
551572 new_key_alt_name = f"{ key_alt_name } .{ field .column } "
552573 path = f"{ path_prefix } .{ field .column } " if path_prefix else field .column
553574
554- # --- Embedded Single Document ---
555- if isinstance (field , EmbeddedModelField ):
575+ if isinstance (field , (EmbeddedModelField , EmbeddedModelArrayField )):
556576 if getattr (field , "encrypted" , False ):
557- # Entire embedded object encrypted
558- data_key = self ._get_data_key (
559- client_encryption ,
560- key_vault_collection ,
561- create_data_keys ,
562- kms_provider ,
563- master_key ,
564- new_key_alt_name ,
577+ bson_type = "object" if isinstance (field , EmbeddedModelField ) else "array"
578+ field_list .append (
579+ _field_dict (
580+ bson_type , path , new_key_alt_name , getattr (field , "queries" , None )
581+ )
565582 )
566- field_dict = {
567- "bsonType" : "object" ,
568- "path" : path ,
569- "keyId" : data_key ,
570- }
571- if getattr (field , "queries" , False ):
572- field_dict ["queries" ] = field .queries
573- field_list .append (field_dict )
574583 else :
575- # Recurse into embedded model
576584 embedded_result = self ._get_encrypted_fields (
577585 field .embedded_model ,
578586 create_data_keys = create_data_keys ,
@@ -581,58 +589,11 @@ def _get_encrypted_fields(
581589 )
582590 if embedded_result and embedded_result .get ("fields" ):
583591 field_list .extend (embedded_result ["fields" ])
584- continue
585-
586- # --- Array of Embedded Documents ---
587- if isinstance (field , EmbeddedModelArrayField ):
588- if getattr (field , "encrypted" , False ):
589- # Entire array contents encrypted - flat entry
590- data_key = self ._get_data_key (
591- client_encryption ,
592- key_vault_collection ,
593- create_data_keys ,
594- kms_provider ,
595- master_key ,
596- new_key_alt_name ,
597- )
598- field_dict = {
599- "bsonType" : "array" ,
600- "path" : path ,
601- "keyId" : data_key ,
602- }
603- if getattr (field , "queries" , False ):
604- field_dict ["queries" ] = field .queries
605- field_list .append (field_dict )
606- else :
607- # Recurse into embedded model for fields inside array elements
608- embedded_result = self ._get_encrypted_fields (
609- field .embedded_model ,
610- create_data_keys = create_data_keys ,
611- key_alt_name = new_key_alt_name ,
612- path_prefix = path , # array prefix in path
613- )
614- if embedded_result and embedded_result .get ("fields" ):
615- field_list .extend (embedded_result ["fields" ])
616- continue
617-
618- # --- Leaf encrypted field ---
619- if getattr (field , "encrypted" , False ):
620- data_key = self ._get_data_key (
621- client_encryption ,
622- key_vault_collection ,
623- create_data_keys ,
624- kms_provider ,
625- master_key ,
626- new_key_alt_name ,
592+ elif getattr (field , "encrypted" , False ):
593+ bson_type = field .db_type (connection )
594+ field_list .append (
595+ _field_dict (bson_type , path , new_key_alt_name , getattr (field , "queries" , None ))
627596 )
628- field_dict = {
629- "bsonType" : field .db_type (connection ),
630- "path" : path ,
631- "keyId" : data_key ,
632- }
633- if getattr (field , "queries" , False ):
634- field_dict ["queries" ] = field .queries
635- field_list .append (field_dict )
636597
637598 return {"fields" : field_list } if field_list else None
638599
0 commit comments