@@ -44,15 +44,9 @@ class User:
4444import warnings
4545from enum import Enum
4646from functools import lru_cache , partial
47+ from typing import Any , Callable , Dict , FrozenSet , Generic , List , Mapping
48+ from typing import NewType as typing_NewType
4749from typing import (
48- Any ,
49- Callable ,
50- Dict ,
51- FrozenSet ,
52- Generic ,
53- List ,
54- Mapping ,
55- NewType as typing_NewType ,
5650 Optional ,
5751 Sequence ,
5852 Set ,
@@ -150,8 +144,7 @@ def dataclass(
150144 frozen : bool = False ,
151145 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
152146 cls_frame : Optional [types .FrameType ] = None ,
153- ) -> Type [_U ]:
154- ...
147+ ) -> Type [_U ]: ...
155148
156149
157150@overload
@@ -164,8 +157,7 @@ def dataclass(
164157 frozen : bool = False ,
165158 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
166159 cls_frame : Optional [types .FrameType ] = None ,
167- ) -> Callable [[Type [_U ]], Type [_U ]]:
168- ...
160+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
169161
170162
171163# _cls should never be specified by keyword, so start it with an
@@ -224,15 +216,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
224216
225217
226218@overload
227- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
228- ...
219+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
229220
230221
231222@overload
232223def add_schema (
233224 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
234- ) -> Callable [[Type [_U ]], Type [_U ]]:
235- ...
225+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
236226
237227
238228@overload
@@ -241,8 +231,7 @@ def add_schema(
241231 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
242232 cls_frame : Optional [types .FrameType ] = None ,
243233 stacklevel : int = 1 ,
244- ) -> Type [_U ]:
245- ...
234+ ) -> Type [_U ]: ...
246235
247236
248237def add_schema (_cls = None , base_schema = None , cls_frame = None , stacklevel = 1 ):
@@ -293,8 +282,7 @@ def class_schema(
293282 * ,
294283 globalns : Optional [Dict [str , Any ]] = None ,
295284 localns : Optional [Dict [str , Any ]] = None ,
296- ) -> Type [marshmallow .Schema ]:
297- ...
285+ ) -> Type [marshmallow .Schema ]: ...
298286
299287
300288@overload
@@ -304,8 +292,7 @@ def class_schema(
304292 clazz_frame : Optional [types .FrameType ] = None ,
305293 * ,
306294 globalns : Optional [Dict [str , Any ]] = None ,
307- ) -> Type [marshmallow .Schema ]:
308- ...
295+ ) -> Type [marshmallow .Schema ]: ...
309296
310297
311298def class_schema (
@@ -463,7 +450,7 @@ def class_schema(
463450 if clazz_frame is not None :
464451 localns = clazz_frame .f_locals
465452 with _SchemaContext (globalns , localns ):
466- return _internal_class_schema (clazz , base_schema )
453+ return _internal_class_schema (clazz , base_schema , None )
467454
468455
469456class _SchemaContext :
@@ -509,10 +496,17 @@ def top(self) -> _U:
509496_schema_ctx_stack = _LocalStack [_SchemaContext ]()
510497
511498
499+ def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
500+ if _is_generic_alias_of_dataclass (clazz ):
501+ clazz = typing_inspect .get_origin (clazz )
502+ return dataclasses .fields (clazz )
503+
504+
512505@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
513506def _internal_class_schema (
514507 clazz : type ,
515508 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
509+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
516510) -> Type [marshmallow .Schema ]:
517511 schema_ctx = _schema_ctx_stack .top
518512
@@ -525,7 +519,7 @@ def _internal_class_schema(
525519 schema_ctx .seen_classes [clazz ] = class_name
526520
527521 try :
528- class_name , fields = _dataclass_name_and_fields (clazz )
522+ fields = _dataclass_fields (clazz )
529523 except TypeError : # Not a dataclass
530524 try :
531525 warnings .warn (
@@ -540,7 +534,9 @@ def _internal_class_schema(
540534 "****** WARNING ******"
541535 )
542536 created_dataclass : type = dataclasses .dataclass (clazz )
543- return _internal_class_schema (created_dataclass , base_schema )
537+ return _internal_class_schema (
538+ created_dataclass , base_schema , generic_params_to_args
539+ )
544540 except Exception as exc :
545541 raise TypeError (
546542 f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -556,6 +552,10 @@ def _internal_class_schema(
556552 # Determine whether we should include non-init fields
557553 include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
558554
555+ if _is_generic_alias_of_dataclass (clazz ) and generic_params_to_args is None :
556+ generic_params_to_args = _generic_params_to_args (clazz )
557+
558+ type_hints = _dataclass_type_hints (clazz , schema_ctx , generic_params_to_args )
559559 # Update the schema members to contain marshmallow fields instead of dataclass fields
560560
561561 if sys .version_info >= (3 , 9 ):
@@ -577,13 +577,14 @@ def _internal_class_schema(
577577 _get_field_default (field ),
578578 field .metadata ,
579579 base_schema ,
580+ generic_params_to_args ,
580581 ),
581582 )
582583 for field in fields
583584 if field .init or include_non_init
584585 )
585586
586- schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
587+ schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
587588 return cast (Type [marshmallow .Schema ], schema_class )
588589
589590
@@ -706,7 +707,7 @@ def _field_for_generic_type(
706707 ),
707708 )
708709 return tuple_type (children , ** metadata )
709- elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
710+ if origin in (dict , Dict , collections .abc .Mapping , Mapping ):
710711 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
711712 return dict_type (
712713 keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
@@ -794,6 +795,7 @@ def field_for_schema(
794795 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
795796 # FIXME: delete typ_frame from API?
796797 typ_frame : Optional [types .FrameType ] = None ,
798+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
797799) -> marshmallow .fields .Field :
798800 """
799801 Get a marshmallow Field corresponding to the given python type.
@@ -953,7 +955,7 @@ def _field_for_schema(
953955 nested_schema
954956 or forward_reference
955957 or _schema_ctx_stack .top .seen_classes .get (typ )
956- or _internal_class_schema (typ , base_schema ) # type: ignore[arg-type] # FIXME
958+ or _internal_class_schema (typ , base_schema , generic_params_to_args ) # type: ignore [arg-type]
957959 )
958960
959961 return marshmallow .fields .Nested (nested , ** metadata )
@@ -1007,35 +1009,38 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
10071009 )
10081010
10091011
1010- # noinspection PyDataclass
1011- def _dataclass_name_and_fields (
1012- clazz : type ,
1013- ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
1014- if not _is_generic_alias_of_dataclass (clazz ):
1015- return clazz .__name__ , dataclasses .fields (clazz )
1016-
1012+ def _generic_params_to_args (clazz : type ) -> Tuple [Tuple [type , type ], ...]:
10171013 base_dataclass = typing_inspect .get_origin (clazz )
10181014 base_parameters = typing_inspect .get_parameters (base_dataclass )
10191015 type_arguments = typing_inspect .get_args (clazz )
1020- params_to_args = dict (zip (base_parameters , type_arguments ))
1021- non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022- (
1023- f .name ,
1024- params_to_args .get (f .type , f .type ),
1025- dataclasses .field (
1026- default = f .default ,
1027- # ignoring mypy: https://github.com/python/mypy/issues/6910
1028- default_factory = f .default_factory , # type: ignore
1029- init = f .init ,
1030- metadata = f .metadata ,
1031- ),
1016+ return tuple (zip (base_parameters , type_arguments ))
1017+
1018+
1019+ def _dataclass_type_hints (
1020+ clazz : type ,
1021+ schema_ctx : _SchemaContext = None ,
1022+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
1023+ ) -> Mapping [str , type ]:
1024+ if not _is_generic_alias_of_dataclass (clazz ):
1025+ return get_type_hints (
1026+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
10321027 )
1033- for f in dataclasses .fields (base_dataclass )
1034- ]
1035- non_generic_dataclass = dataclasses .make_dataclass (
1036- cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
1028+ # dataclass is generic
1029+ generic_type_hints = get_type_hints (
1030+ typing_inspect .get_origin (clazz ),
1031+ globalns = schema_ctx .globalns ,
1032+ localns = schema_ctx .localns ,
10371033 )
1038- return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
1034+ generic_params_map = dict (generic_params_to_args if generic_params_to_args else {})
1035+
1036+ def _get_hint (_t : type ) -> type :
1037+ if isinstance (_t , TypeVar ):
1038+ return generic_params_map [_t ]
1039+ return _t
1040+
1041+ return {
1042+ field_name : _get_hint (typ ) for field_name , typ in generic_type_hints .items ()
1043+ }
10391044
10401045
10411046def NewType (
0 commit comments