@@ -453,7 +453,9 @@ def class_schema(
453453 >>> class_schema(Custom)().load({})
454454 Custom(name=None)
455455 """
456- if not dataclasses .is_dataclass (clazz ):
456+ if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
457+ clazz
458+ ):
457459 clazz = dataclasses .dataclass (clazz )
458460 if localns is None :
459461 if clazz_frame is None :
@@ -523,8 +525,7 @@ def _internal_class_schema(
523525 schema_ctx .seen_classes [clazz ] = class_name
524526
525527 try :
526- # noinspection PyDataclass
527- fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
528+ class_name , fields = _dataclass_name_and_fields (clazz )
528529 except TypeError : # Not a dataclass
529530 try :
530531 warnings .warn (
@@ -582,7 +583,7 @@ def _internal_class_schema(
582583 if field .init or include_non_init
583584 )
584585
585- schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
586+ schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
586587 return cast (Type [marshmallow .Schema ], schema_class )
587588
588589
@@ -996,6 +997,47 @@ def _get_field_default(field: dataclasses.Field):
996997 return field .default
997998
998999
1000+ def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
1001+ """
1002+ Check if given class is a generic alias of a dataclass, if the dataclass is
1003+ defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1004+ """
1005+ return typing_inspect .is_generic_type (clazz ) and dataclasses .is_dataclass (
1006+ typing_inspect .get_origin (clazz )
1007+ )
1008+
1009+
1010+ # noinspection PyDataclass
1011+ def _dataclass_name_and_fields (
1012+ clazz : type ,
1013+ ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
1014+ if not _is_generic_alias_of_dataclass (clazz ):
1015+ return clazz .__name__ , dataclasses .fields (clazz )
1016+
1017+ base_dataclass = typing_inspect .get_origin (clazz )
1018+ base_parameters = typing_inspect .get_parameters (base_dataclass )
1019+ type_arguments = typing_inspect .get_args (clazz )
1020+ params_to_args = dict (zip (base_parameters , type_arguments ))
1021+ non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022+ (
1023+ f .name ,
1024+ params_to_args .get (f .type , f .type ),
1025+ dataclasses .field (
1026+ default = f .default ,
1027+ # ignoring mypy: https://github.com/python/mypy/issues/6910
1028+ default_factory = f .default_factory , # type: ignore
1029+ init = f .init ,
1030+ metadata = f .metadata ,
1031+ ),
1032+ )
1033+ for f in dataclasses .fields (base_dataclass )
1034+ ]
1035+ non_generic_dataclass = dataclasses .make_dataclass (
1036+ cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
1037+ )
1038+ return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
1039+
1040+
9991041def NewType (
10001042 name : str ,
10011043 typ : Type [_U ],
0 commit comments