Skip to content

Commit 1206050

Browse files
authored
DGS-22734 Add Accept-Version header (#2117)
* DGS-22734 Add Accept-Version header
* Add aws fix
* Fix Avro bytes serialization
1 parent f37ec78 commit 1206050

File tree

8 files changed

+121
-29
lines changed

8 files changed

+121
-29
lines changed

src/confluent_kafka/schema_registry/_async/avro.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,16 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
348348
parsed_schema = self._parsed_schema
349349

350350
with _ContextStringIO() as fo:
351-
# write the record to the rest of the buffer
352-
schemaless_writer(fo, parsed_schema, value)
353-
buffer = fo.getvalue()
351+
# Check if it's a simple bytes type
352+
is_bytes = (parsed_schema == "bytes" or
353+
(isinstance(parsed_schema, dict) and parsed_schema.get("type") == "bytes"))
354+
if is_bytes:
355+
# For simple bytes type, write value directly
356+
buffer = value if isinstance(value, bytes) else value.encode()
357+
else:
358+
# write the record to the rest of the buffer
359+
schemaless_writer(fo, parsed_schema, value)
360+
buffer = fo.getvalue()
354361

355362
if latest_schema is not None:
356363
buffer = self._execute_rules_with_phase(
@@ -585,17 +592,29 @@ async def __deserialize(
585592
reader_schema_raw = writer_schema_raw
586593
reader_schema = writer_schema
587594

595+
# Check if it's a simple bytes type
596+
is_bytes = (writer_schema == "bytes" or
597+
(isinstance(writer_schema, dict) and writer_schema.get("type") == "bytes"))
598+
588599
if migrations:
589-
obj_dict = schemaless_reader(payload,
590-
writer_schema,
591-
None,
592-
self._return_record_name)
600+
if is_bytes:
601+
# For simple bytes type, read payload directly
602+
obj_dict = payload.read()
603+
else:
604+
obj_dict = schemaless_reader(payload,
605+
writer_schema,
606+
None,
607+
self._return_record_name)
593608
obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
594609
else:
595-
obj_dict = schemaless_reader(payload,
596-
writer_schema,
597-
reader_schema,
598-
self._return_record_name)
610+
if is_bytes:
611+
# For simple bytes type, read payload directly
612+
obj_dict = payload.read()
613+
else:
614+
obj_dict = schemaless_reader(payload,
615+
writer_schema,
616+
reader_schema,
617+
self._return_record_name)
599618

600619
def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731
601620
transform(rule_ctx, reader_schema, message, field_transform))

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,8 @@ async def send_request(
440440
if body is not None:
441441
body = json.dumps(body)
442442
headers = {'Content-Length': str(len(body)),
443-
'Content-Type': "application/vnd.schemaregistry.v1+json"}
443+
'Content-Type': "application/vnd.schemaregistry.v1+json",
444+
'Accept-Version': "8.0"}
444445

445446
if self.bearer_auth_credentials_source:
446447
await self.handle_bearer_auth(headers)

src/confluent_kafka/schema_registry/_sync/avro.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,16 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731
348348
parsed_schema = self._parsed_schema
349349

350350
with _ContextStringIO() as fo:
351-
# write the record to the rest of the buffer
352-
schemaless_writer(fo, parsed_schema, value)
353-
buffer = fo.getvalue()
351+
# Check if it's a simple bytes type
352+
is_bytes = (parsed_schema == "bytes" or
353+
(isinstance(parsed_schema, dict) and parsed_schema.get("type") == "bytes"))
354+
if is_bytes:
355+
# For simple bytes type, write value directly
356+
buffer = value if isinstance(value, bytes) else value.encode()
357+
else:
358+
# write the record to the rest of the buffer
359+
schemaless_writer(fo, parsed_schema, value)
360+
buffer = fo.getvalue()
354361

355362
if latest_schema is not None:
356363
buffer = self._execute_rules_with_phase(
@@ -585,17 +592,29 @@ def __deserialize(
585592
reader_schema_raw = writer_schema_raw
586593
reader_schema = writer_schema
587594

595+
# Check if it's a simple bytes type
596+
is_bytes = (writer_schema == "bytes" or
597+
(isinstance(writer_schema, dict) and writer_schema.get("type") == "bytes"))
598+
588599
if migrations:
589-
obj_dict = schemaless_reader(payload,
590-
writer_schema,
591-
None,
592-
self._return_record_name)
600+
if is_bytes:
601+
# For simple bytes type, read payload directly
602+
obj_dict = payload.read()
603+
else:
604+
obj_dict = schemaless_reader(payload,
605+
writer_schema,
606+
None,
607+
self._return_record_name)
593608
obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
594609
else:
595-
obj_dict = schemaless_reader(payload,
596-
writer_schema,
597-
reader_schema,
598-
self._return_record_name)
610+
if is_bytes:
611+
# For simple bytes type, read payload directly
612+
obj_dict = payload.read()
613+
else:
614+
obj_dict = schemaless_reader(payload,
615+
writer_schema,
616+
reader_schema,
617+
self._return_record_name)
599618

600619
def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731
601620
transform(rule_ctx, reader_schema, message, field_transform))

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,8 @@ def send_request(
440440
if body is not None:
441441
body = json.dumps(body)
442442
headers = {'Content-Length': str(len(body)),
443-
'Content-Type': "application/vnd.schemaregistry.v1+json"}
443+
'Content-Type': "application/vnd.schemaregistry.v1+json",
444+
'Accept-Version': "8.0"}
444445

445446
if self.bearer_auth_credentials_source:
446447
self.handle_bearer_auth(headers)

src/confluent_kafka/schema_registry/rules/encryption/awskms/aws_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def new_kms_client(self, conf: Dict[str, Any], key_url: Optional[str]) -> KmsCli
5555
role_external_id = conf.get(_ROLE_EXTERNAL_ID)
5656
if role_external_id is None:
5757
role_external_id = os.getenv("AWS_ROLE_EXTERNAL_ID")
58+
role_web_identity_token_file = os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
5859
key = conf.get(_ACCESS_KEY_ID)
5960
secret = conf.get(_SECRET_ACCESS_KEY)
6061
profile = conf.get(_PROFILE)
@@ -74,7 +75,8 @@ def new_kms_client(self, conf: Dict[str, Any], key_url: Optional[str]) -> KmsCli
7475
)
7576
else:
7677
session = boto3.Session(region_name=region)
77-
if role_arn is not None:
78+
# If role_web_identity_token_file is set, use the DefaultCredentialsProvider
79+
if role_arn is not None and role_web_identity_token_file is None:
7880
sts_client = session.client('sts')
7981
params = {
8082
'RoleArn': role_arn,

src/confluent_kafka/schema_registry/rules/encryption/dek_registry/dek_registry_client.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,10 +587,7 @@ def register_dek(
587587
encrypted_key_material=encrypted_key_material
588588
)
589589

590-
response = self._rest_client.post('/dek-registry/v1/keks/{}/deks'
591-
.format(urllib.parse.quote(kek_name)),
592-
request.to_dict())
593-
dek = Dek.from_dict(response)
590+
dek = self._create_dek(kek_name, request)
594591

595592
self._dek_cache.set(cache_key, dek)
596593
# Ensure latest dek is invalidated, such as in case of conflict (409)
@@ -611,6 +608,27 @@ def register_dek(
611608

612609
return dek
613610

611+
def _create_dek(
    self, kek_name: str, request: CreateDekRequest
) -> Dek:
    """Create a DEK under ``kek_name`` via the REST client.

    The subject-scoped endpoint (``.../deks/{subject}``) is tried first. If
    the server answers with HTTP 405, the same request is re-posted to the
    older endpoint that has no subject in the path.

    Args:
        kek_name: Name of the KEK the DEK belongs to.
        request: The DEK creation request; its ``subject`` is used in the
            path and ``to_dict()`` supplies the request body.

    Returns:
        The ``Dek`` built from the registry response.

    Raises:
        SchemaRegistryError: Re-raised unchanged for any status other than 405.
    """
    from confluent_kafka.schema_registry.error import SchemaRegistryError

    quoted_kek = urllib.parse.quote(kek_name)
    try:
        # Newer API: subject is part of the path, so '/' in it must be escaped
        subject_path = '/dek-registry/v1/keks/{}/deks/{}'.format(
            quoted_kek,
            urllib.parse.quote(request.subject, safe=''),
        )
        return Dek.from_dict(
            self._rest_client.post(subject_path, request.to_dict())
        )
    except SchemaRegistryError as err:
        if err.http_status_code != 405:
            raise
        # Older API: no subject in the path
        legacy_path = '/dek-registry/v1/keks/{}/deks'.format(quoted_kek)
        return Dek.from_dict(
            self._rest_client.post(legacy_path, request.to_dict())
        )
631+
614632
def get_dek(
615633
self, kek_name: str, subject: str, algorithm: DekAlgorithm = DekAlgorithm.AES256_GCM,
616634
version: int = 1, deleted: bool = False

tests/schema_registry/_async/test_avro_serdes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ async def test_avro_serialize_use_schema_id():
198198
assert obj == obj2
199199

200200

201+
async def test_avro_serialize_bytes():
    """Round-trip a raw payload through a top-level Avro ``bytes`` schema."""
    client = AsyncSchemaRegistryClient.new_client({'url': _BASE_URL})
    payload = b'\x02\x03\x04'
    serializer = await AsyncAvroSerializer(
        client,
        schema_str=json.dumps('bytes'),
        conf={'auto.register.schemas': True},
    )
    ctx = SerializationContext(_TOPIC, MessageField.VALUE)

    encoded = await serializer(payload, ctx)
    # Five leading bytes precede the payload, which must appear verbatim
    assert b'\x00\x00\x00\x00\x01\x02\x03\x04' == encoded

    deserializer = await AsyncAvroDeserializer(client)
    decoded = await deserializer(encoded, ctx)
    assert payload == decoded
215+
216+
201217
async def test_avro_serialize_nested():
202218
conf = {'url': _BASE_URL}
203219
client = AsyncSchemaRegistryClient.new_client(conf)

tests/schema_registry/_sync/test_avro_serdes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ def test_avro_serialize_use_schema_id():
198198
assert obj == obj2
199199

200200

201+
def test_avro_serialize_bytes():
    """Round-trip a raw payload through a top-level Avro ``bytes`` schema."""
    client = SchemaRegistryClient.new_client({'url': _BASE_URL})
    payload = b'\x02\x03\x04'
    serializer = AvroSerializer(
        client,
        schema_str=json.dumps('bytes'),
        conf={'auto.register.schemas': True},
    )
    ctx = SerializationContext(_TOPIC, MessageField.VALUE)

    encoded = serializer(payload, ctx)
    # Five leading bytes precede the payload, which must appear verbatim
    assert b'\x00\x00\x00\x00\x01\x02\x03\x04' == encoded

    deserializer = AvroDeserializer(client)
    decoded = deserializer(encoded, ctx)
    assert payload == decoded
215+
216+
201217
def test_avro_serialize_nested():
202218
conf = {'url': _BASE_URL}
203219
client = SchemaRegistryClient.new_client(conf)

0 commit comments

Comments
 (0)