From c97ae84266ff481936b6120f51e80f2b3d5748e4 Mon Sep 17 00:00:00 2001 From: etareduction Date: Mon, 29 Sep 2025 06:13:30 +0300 Subject: [PATCH 1/2] Add a setting that sets default=None to all optional fields so they can be missing from response and still parse --- ariadne_codegen/client_generators/fragments.py | 3 +++ ariadne_codegen/client_generators/package.py | 5 +++++ ariadne_codegen/client_generators/result_types.py | 10 ++++++++++ ariadne_codegen/settings.py | 1 + 4 files changed, 19 insertions(+) diff --git a/ariadne_codegen/client_generators/fragments.py b/ariadne_codegen/client_generators/fragments.py index 3a7d2ecc..13409d2f 100644 --- a/ariadne_codegen/client_generators/fragments.py +++ b/ariadne_codegen/client_generators/fragments.py @@ -20,6 +20,7 @@ def __init__( convert_to_snake_case: bool = True, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, + default_optional_fields_to_none: bool = False, ) -> None: self.schema = schema self.enums_module_name = enums_module_name @@ -28,6 +29,7 @@ def __init__( self.convert_to_snake_case = convert_to_snake_case self.custom_scalars = custom_scalars self.plugin_manager = plugin_manager + self.default_optional_fields_to_none = default_optional_fields_to_none self._fragments_names = set(self.fragments_definitions.keys()) self._generated_public_names: List[str] = [] @@ -52,6 +54,7 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: convert_to_snake_case=self.convert_to_snake_case, custom_scalars=self.custom_scalars, plugin_manager=self.plugin_manager, + default_optional_fields_to_none=self.default_optional_fields_to_none, ) imports.extend(generator.get_imports()) class_defs = generator.get_classes() diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index 529ad7ba..8797d21e 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -84,6 +84,7 @@ def __init__( custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, enable_custom_operations: bool = False, + default_optional_fields_to_none: bool = False, ) -> None: self.package_path = Path(target_path) / package_name @@ -133,6 +134,7 @@ def __init__( ) self.custom_scalars = custom_scalars if custom_scalars else {} self.plugin_manager = plugin_manager + self.default_optional_fields_to_none = default_optional_fields_to_none self._result_types_files: Dict[str, ast.Module] = {} self._generated_files: List[str] = [] @@ -199,6 +201,7 @@ def add_operation(self, definition: OperationDefinitionNode): convert_to_snake_case=self.convert_to_snake_case, custom_scalars=self.custom_scalars, plugin_manager=self.plugin_manager, + default_optional_fields_to_none=self.default_optional_fields_to_none, ) self._unpacked_fragments = self._unpacked_fragments.union( query_types_generator.get_unpacked_fragments() @@ -454,6 +457,7 @@ def get_package_generator( convert_to_snake_case=settings.convert_to_snake_case, custom_scalars=settings.scalars, plugin_manager=plugin_manager, + default_optional_fields_to_none=settings.default_optional_fields_to_none, ) custom_fields_generator = CustomFieldsGenerator( schema=schema, @@ -533,4 +537,5 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, enable_custom_operations=settings.enable_custom_operations, + default_optional_fields_to_none=settings.default_optional_fields_to_none, ) diff --git a/ariadne_codegen/client_generators/result_types.py b/ariadne_codegen/client_generators/result_types.py index bf29ec46..c4ed026b 100644 --- a/ariadne_codegen/client_generators/result_types.py +++ b/ariadne_codegen/client_generators/result_types.py @@ -85,6 +85,7 @@ def __init__( convert_to_snake_case: bool = True, custom_scalars: Optional[Dict[str, ScalarData]] = None, plugin_manager: Optional[PluginManager] = None, + default_optional_fields_to_none: bool = False, ) -> None: self.schema = schema self.operation_definition = operation_definition @@ -99,6 +100,7 @@ def __init__( self.custom_scalars = custom_scalars if custom_scalars else {} self.convert_to_snake_case = convert_to_snake_case self.plugin_manager = plugin_manager + self.default_optional_fields_to_none = default_optional_fields_to_none self._imports: List[ast.ImportFrom] = [ generate_import_from( @@ -443,6 +445,14 @@ def _process_field_implementation( keywords[DEFAULT_KEYWORD] = generate_constant( field_implementation.value.value ) + elif ( + self.default_optional_fields_to_none + and field_implementation.value is None + and isinstance(field_implementation.annotation, ast.Subscript) + and isinstance(field_implementation.annotation.value, ast.Name) + and field_implementation.annotation.value.id == OPTIONAL + ): + keywords[DEFAULT_KEYWORD] = generate_constant(None) if keywords: field_implementation.value = generate_pydantic_field(keywords) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 663148a8..8e651f38 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -72,6 +72,7 @@ class ClientSettings(BaseSettings): opentelemetry_client: bool = False files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) + default_optional_fields_to_none: bool = False def __post_init__(self): if not self.queries_path and not self.enable_custom_operations: From b660e5a8c14f7e340d3f8f02637c4f1fee8acd72 Mon Sep 17 00:00:00 2001 From: GoodGrief1488 <55449535+GoodGrief1488@users.noreply.github.com> Date: Tue, 21 Oct 2025 22:15:20 +0300 Subject: [PATCH 2/2] Unit tests for default_optional_fields_to_none setting --- .../test_default_optional_fields.py | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 tests/client_generators/result_types_generator/test_default_optional_fields.py diff --git a/tests/client_generators/result_types_generator/test_default_optional_fields.py b/tests/client_generators/result_types_generator/test_default_optional_fields.py new file mode 100644 index 00000000..d00b85e7 --- /dev/null +++ b/tests/client_generators/result_types_generator/test_default_optional_fields.py @@ -0,0 +1,138 @@ +import ast +from typing import cast + +from graphql import ( + OperationDefinitionNode, + build_ast_schema, + parse, +) + +from ariadne_codegen.client_generators.constants import ( + ALIAS_KEYWORD, + DEFAULT_KEYWORD, + FIELD_CLASS, + OPTIONAL, +) +from ariadne_codegen.client_generators.result_types import ResultTypesGenerator + +from ...utils import compare_ast, format_graphql_str, get_class_def +from .schema import SCHEMA_STR + + +def test_default_optional_fields_true(): + query_str = format_graphql_str( + """ + query CustomQuery { + query1 { + ... on CustomType { + field1 + field2 + } + } + } + """ + ) + expected_results = [ + ast.AnnAssign( + target=ast.Name(id="field_1"), + annotation=ast.Name(id='"CustomQueryQuery1Field1"'), + value=ast.Call( + func=ast.Name(id=FIELD_CLASS), + args=[], + keywords=[ + ast.keyword( + arg=ALIAS_KEYWORD, + value=ast.Constant(value="field1"), + ) + ], + ), + simple=1, + ), + ast.AnnAssign( + target=ast.Name(id="field_2"), + annotation=ast.Subscript( + value=ast.Name(id=OPTIONAL), + slice=ast.Name(id='"CustomQueryQuery1Field2"'), + ), + value=ast.Call( + func=ast.Name(id=FIELD_CLASS), + args=[], + keywords=[ + ast.keyword(arg=ALIAS_KEYWORD, value=ast.Constant(value="field2")), + ast.keyword(arg=DEFAULT_KEYWORD, value=ast.Constant(value=None)), + ], + ), + simple=1, + ), + ] + generator = ResultTypesGenerator( + schema=build_ast_schema(parse(SCHEMA_STR)), + operation_definition=cast( + OperationDefinitionNode, parse(query_str).definitions[0] + ), + enums_module_name="enums", + default_optional_fields_to_none=True, + ) + result = generator.generate() + classdef = get_class_def(result, 1) + assert compare_ast(classdef.body[0], expected_results[0]) + assert compare_ast(classdef.body[1], expected_results[1]) + + +def test_default_optional_fields_false(): + query_str = format_graphql_str( + """ + query CustomQuery { + query1 { + ... on CustomType { + field1 + field2 + } + } + } + """ + ) + expected_results = [ + ast.AnnAssign( + target=ast.Name(id="field_1"), + annotation=ast.Name(id='"CustomQueryQuery1Field1"'), + value=ast.Call( + func=ast.Name(id=FIELD_CLASS), + args=[], + keywords=[ + ast.keyword( + arg=ALIAS_KEYWORD, + value=ast.Constant(value="field1"), + ) + ], + ), + simple=1, + ), + ast.AnnAssign( + target=ast.Name(id="field_2"), + annotation=ast.Subscript( + value=ast.Name(id=OPTIONAL), + slice=ast.Name(id='"CustomQueryQuery1Field2"'), + ), + value=ast.Call( + func=ast.Name(id=FIELD_CLASS), + args=[], + keywords=[ + ast.keyword(arg=ALIAS_KEYWORD, value=ast.Constant(value="field2")) + ], + ), + simple=1, + ), + ] + generator = ResultTypesGenerator( + schema=build_ast_schema(parse(SCHEMA_STR)), + operation_definition=cast( + OperationDefinitionNode, parse(query_str).definitions[0] + ), + enums_module_name="enums", + default_optional_fields_to_none=False, + ) + result = generator.generate() + classdef = get_class_def(result, 1) + assert compare_ast(classdef.body[0], expected_results[0]) + assert compare_ast(classdef.body[1], expected_results[1])