11import datetime
22from decimal import Decimal
33
4+ import pymongo
5+ from bson .binary import Binary
6+ from django .conf import settings
7+ from django .db import connections
8+ from django .db .models import Model
9+
410from django_mongodb_backend .fields import EncryptedCharField
511
612from .models import (
3238from .test_base import EncryptionTestCase
3339
3440
35- class EncryptedEmbeddedModelTests (EncryptionTestCase ):
41+ class EncryptedFieldTests (EncryptionTestCase ):
42+ def assertEncrypted (self , model_or_instance , field_name ):
43+ """
44+ Check if the field value in the database is stored as Binary.
45+ Works with either a Django model instance or a model class.
46+ """
47+
48+ conn_params = connections ["encrypted" ].get_connection_params ()
49+ db_name = settings .DATABASES ["encrypted" ]["NAME" ]
50+
51+ if conn_params .pop ("auto_encryption_opts" , False ):
52+ with pymongo .MongoClient (** conn_params ) as new_connection :
53+ if hasattr (model_or_instance , "_meta" ):
54+ collection_name = model_or_instance ._meta .db_table
55+ else :
56+ self .fail (f"Object { model_or_instance !r} is not a Django model or instance" )
57+
58+ collection = new_connection [db_name ][collection_name ]
59+
60+ # If it's an instance of a Django model, narrow to that _id
61+ if isinstance (model_or_instance , Model ):
62+ docs = collection .find (
63+ {"_id" : model_or_instance .pk , field_name : {"$exists" : True }}
64+ )
65+ else :
66+ # Otherwise it's a model class
67+ docs = collection .find ({field_name : {"$exists" : True }})
68+
69+ found = False
70+ for doc in docs :
71+ found = True
72+ value = doc .get (field_name )
73+ self .assertTrue (
74+ isinstance (value , Binary ),
75+ msg = f"Field '{ field_name } ' in document { doc ['_id' ]} is "
76+ "not encrypted (type={type(value)})" ,
77+ )
78+
79+ self .assertTrue (
80+ found ,
81+ msg = f"No documents with field '{ field_name } ' found in '{{collection_name}}'" ,
82+ )
83+
84+ else :
85+ self .fail ("auto_encryption_opts is not configured; encryption not enabled." )
86+
87+
88+ class EncryptedEmbeddedModelTests (EncryptedFieldTests ):
3689 def setUp (self ):
3790 self .billing = Billing (cc_type = "Visa" , cc_number = "4111111111111111" )
3891 self .patient_record = PatientRecord (ssn = "123-45-6789" , billing = self .billing )
3992 self .patient = Patient .objects .create (
4093 patient_name = "John Doe" , patient_id = 123456789 , patient_record = self .patient_record
4194 )
4295
43- def test_patient (self ):
96+ def test_object (self ):
4497 patient = Patient .objects .get (id = self .patient .id )
4598 self .assertEqual (patient .patient_record .ssn , "123-45-6789" )
4699 self .assertEqual (patient .patient_record .billing .cc_type , "Visa" )
47100 self .assertEqual (patient .patient_record .billing .cc_number , "4111111111111111" )
48101
49102
50- class EncryptedEmbeddedModelArrayTests (EncryptionTestCase ):
103+ class EncryptedEmbeddedModelArrayTests (EncryptedFieldTests ):
51104 def setUp (self ):
52105 self .actor1 = Actor (name = "Actor One" )
53106 self .actor2 = Actor (name = "Actor Two" )
@@ -56,13 +109,14 @@ def setUp(self):
56109 cast = [self .actor1 , self .actor2 ],
57110 )
58111
59- def test_movie_actors (self ):
112+ def test_array (self ):
60113 self .assertEqual (len (self .movie .cast ), 2 )
61114 self .assertEqual (self .movie .cast [0 ].name , "Actor One" )
62115 self .assertEqual (self .movie .cast [1 ].name , "Actor Two" )
116+ self .assertEncrypted (self .movie , "cast" )
63117
64118
65- class EncryptedFieldTests (EncryptionTestCase ):
119+ class EncryptedFieldTests (EncryptedFieldTests ):
66120 def assertEquality (self , model_cls , val ):
67121 model_cls .objects .create (value = val )
68122 fetched = model_cls .objects .get (value = val )
@@ -80,28 +134,36 @@ def assertRange(self, model_cls, *, low, high, threshold):
80134 # Equality-only fields
81135 def test_binary (self ):
82136 self .assertEquality (BinaryModel , b"\x00 \x01 \x02 " )
137+ self .assertEncrypted (BinaryModel , "value" )
83138
84139 def test_boolean (self ):
85140 self .assertEquality (BooleanModel , True )
141+ self .assertEncrypted (BooleanModel , "value" )
86142
87143 def test_char (self ):
88144 self .assertEquality (CharModel , "hello" )
145+ self .assertEncrypted (CharModel , "value" )
89146
90147 def test_email (self ):
91148 self .assertEquality (EmailModel , "test@example.com" )
149+ self .assertEncrypted (EmailModel , "value" )
92150
93151 def test_ip (self ):
94152 self .assertEquality (GenericIPAddressModel , "192.168.0.1" )
153+ self .assertEncrypted (GenericIPAddressModel , "value" )
95154
96155 def test_text (self ):
97156 self .assertEquality (TextModel , "some text" )
157+ self .assertEncrypted (TextModel , "value" )
98158
99159 def test_url (self ):
100160 self .assertEquality (URLModel , "https://example.com" )
161+ self .assertEncrypted (URLModel , "value" )
101162
102163 # Range fields
103164 def test_big_integer (self ):
104165 self .assertRange (BigIntegerModel , low = 100 , high = 200 , threshold = 150 )
166+ self .assertEncrypted (BigIntegerModel , "value" )
105167
106168 def test_date (self ):
107169 self .assertRange (
@@ -110,6 +172,7 @@ def test_date(self):
110172 high = datetime .date (2024 , 6 , 10 ),
111173 threshold = datetime .date (2024 , 6 , 5 ),
112174 )
175+ self .assertEncrypted (DateModel , "value" )
113176
114177 def test_datetime (self ):
115178 self .assertRange (
@@ -118,6 +181,7 @@ def test_datetime(self):
118181 high = datetime .datetime (2024 , 6 , 2 , 12 , 0 ),
119182 threshold = datetime .datetime (2024 , 6 , 2 , 0 , 0 ),
120183 )
184+ self .assertEncrypted (DateTimeModel , "value" )
121185
122186 def test_decimal (self ):
123187 self .assertRange (
@@ -126,6 +190,7 @@ def test_decimal(self):
126190 high = Decimal ("200.50" ),
127191 threshold = Decimal ("150" ),
128192 )
193+ self .assertEncrypted (DecimalModel , "value" )
129194
130195 def test_duration (self ):
131196 self .assertRange (
@@ -134,24 +199,31 @@ def test_duration(self):
134199 high = datetime .timedelta (days = 10 ),
135200 threshold = datetime .timedelta (days = 5 ),
136201 )
202+ self .assertEncrypted (DurationModel , "value" )
137203
138204 def test_float (self ):
139205 self .assertRange (FloatModel , low = 1.23 , high = 4.56 , threshold = 3.0 )
206+ self .assertEncrypted (FloatModel , "value" )
140207
141208 def test_integer (self ):
142209 self .assertRange (IntegerModel , low = 5 , high = 10 , threshold = 7 )
210+ self .assertEncrypted (IntegerModel , "value" )
143211
144212 def test_positive_big_integer (self ):
145213 self .assertRange (PositiveBigIntegerModel , low = 100 , high = 500 , threshold = 200 )
214+ self .assertEncrypted (PositiveBigIntegerModel , "value" )
146215
147216 def test_positive_integer (self ):
148217 self .assertRange (PositiveIntegerModel , low = 10 , high = 20 , threshold = 15 )
218+ self .assertEncrypted (PositiveIntegerModel , "value" )
149219
150220 def test_positive_small_integer (self ):
151221 self .assertRange (PositiveSmallIntegerModel , low = 5 , high = 8 , threshold = 6 )
222+ self .assertEncrypted (PositiveSmallIntegerModel , "value" )
152223
153224 def test_small_integer (self ):
154225 self .assertRange (SmallIntegerModel , low = - 5 , high = 2 , threshold = 0 )
226+ self .assertEncrypted (SmallIntegerModel , "value" )
155227
156228 def test_time (self ):
157229 self .assertRange (
@@ -160,9 +232,10 @@ def test_time(self):
160232 high = datetime .time (15 , 0 ),
161233 threshold = datetime .time (12 , 0 ),
162234 )
235+ self .assertEncrypted (TimeModel , "value" )
163236
164237
165- class EncryptedFieldMixinTests (EncryptionTestCase ):
238+ class EncryptedFieldMixinTests (EncryptedFieldTests ):
166239 def test_null_true_raises_error (self ):
167240 with self .assertRaisesMessage (
168241 ValueError , "'null=True' is not supported for encrypted fields."
0 commit comments