11from collections import OrderedDict
2+ from functools import partial
23from graphql .core .type import GraphQLObjectType
34from ..utils .get_declared_fields import get_declared_fields
45from ..utils .make_default_resolver import make_default_resolver
56from ..utils .no_implementation_registration import no_implementation_registration
7+ from ..utils .ref_holder import RefHolder
68from ..utils .yank_potential_fields import yank_potential_fields
79
810
911class ObjectTypeMeta (type ):
1012 def __new__ (mcs , name , bases , attrs ):
11- if attrs .get ('abstract' ):
13+ if attrs .pop ('abstract' , False ):
1214 return super (ObjectTypeMeta , mcs ).__new__ (mcs , name , bases , attrs )
1315
16+ class_ref = RefHolder ()
1417 declared_fields = get_declared_fields (name , yank_potential_fields (attrs ))
1518 with no_implementation_registration ():
1619 object_type = GraphQLObjectType (
1720 name ,
18- fields = lambda : mcs ._build_field_map ( attrs , declared_fields ),
21+ fields = partial ( mcs ._build_field_map , class_ref , declared_fields ),
1922 description = attrs .get ('__doc__' ),
2023 interfaces = mcs ._get_interfaces ()
2124 )
2225
2326 mcs ._register (object_type )
24- attrs ['_registry' ] = mcs ._get_registry ()
25- attrs ['T' ] = object_type
2627 cls = super (ObjectTypeMeta , mcs ).__new__ (mcs , name , bases , attrs )
27- attrs ['_cls' ] = cls
28+ cls .T = object_type
29+ cls ._registry = mcs ._get_registry ()
30+ class_ref .set (cls )
31+
2832 return cls
2933
3034 @staticmethod
@@ -36,10 +40,14 @@ def _get_registry():
3640 raise NotImplementedError ('_get_registry must be implemented in the sub-metaclass' )
3741
3842 @staticmethod
39- def _build_field_map (attrs , declared_fields ):
40- instance = attrs ['_cls' ]()
41- type = attrs ['T' ]
42- registry = attrs ['_registry' ]
43+ def _build_field_map (class_ref , declared_fields ):
44+ cls = class_ref .get ()
45+ if not cls :
46+ return
47+
48+ instance = cls (__field_map_init = True )
49+ type = cls .T
50+ registry = cls ._registry
4351 interfaces = type .get_interfaces ()
4452 fields = []
4553
@@ -54,6 +62,7 @@ def _build_field_map(attrs, declared_fields):
5462
5563 fields += declared_fields
5664 field_map = OrderedDict ()
65+ field_attr_map = OrderedDict ()
5766
5867 for field_attr_name , field in fields :
5968 resolve_fn = (
@@ -70,8 +79,15 @@ def _build_field_map(attrs, declared_fields):
7079 if field .name in field_map :
7180 del field_map [field .name ]
7281
73- field_map [field .name ] = field .to_field (registry , resolve_fn )
82+ graphql_field = field .to_field (registry , resolve_fn )
83+ field_map [field .name ] = graphql_field
84+
85+ if field_attr_name in field_attr_map :
86+ del field_attr_map [field_attr_name ]
87+
88+ field_attr_map [field_attr_name ] = graphql_field
7489
90+ cls ._field_attr_map = field_attr_map
7591 return field_map
7692
7793 @staticmethod
0 commit comments