1+ from collections import defaultdict
12from enum import Enum
2-
3+ from functools import partial
34from graphql .core .type import (
45 GraphQLBoolean ,
56 GraphQLEnumType ,
1920
2021import six
2122from .bases .object_type import ObjectTypeBase
23+ from .bases .class_type_creator import ClassTypeCreator
2224from .field import Field
2325from .metaclasses .interface import InterfaceMeta
2426from .metaclasses .object_type import ObjectTypeMeta
@@ -44,8 +46,10 @@ def __init__(self):
4446 self ._registered_types = {}
4547 self ._added_impl_types = set ()
4648 self ._interface_declared_fields = {}
49+ self ._registered_types_can_be = defaultdict (set )
4750 self .ObjectType = self ._create_object_type_class ()
48- self .Implements = self ._create_implement_type_class ()
51+ self .Implements = ClassTypeCreator (self , self ._create_object_type_class )
52+ self .Union = ClassTypeCreator (self , self ._create_union_type_class )
4953 self .Interface = self ._create_interface_type_class ()
5054
5155 for type in builtin_scalars :
@@ -66,8 +70,13 @@ def register(self, t):
6670 @register .register (GraphQLInputObjectType )
6771 @register .register (GraphQLScalarType )
6872 def register_ (self , t ):
69- assert t .name not in ('ObjectType' , 'Implements' , 'Interface' )
70- assert t .name not in self ._registered_types
73+ assert not t .name .startswith ('_' ), \
74+ 'Registered type name cannot start with an "_".'
75+ assert t .name not in ('ObjectType' , 'Implements' , 'Interface' , 'Schema' ), \
76+ 'You cannot register a type named "{}".' .format (type .name )
77+ assert t .name not in self ._registered_types , \
78+ 'There is already a registered type named "{}".' .format (type .name )
79+
7180 self ._registered_types [t .name ] = t
7281 return t
7382
@@ -103,8 +112,9 @@ def _create_object_type_class(self, interface_thunk=None):
103112
104113 class RegistryObjectTypeMeta (ObjectTypeMeta ):
105114 @staticmethod
106- def _register (object_type ):
115+ def _register (object_type , type_class ):
107116 registry .register (object_type )
117+ registry ._registered_types_can_be [object_type ].add (type_class )
108118
109119 @staticmethod
110120 def _get_registry ():
@@ -123,25 +133,6 @@ class ObjectType(ObjectTypeBase):
123133
124134 return ObjectType
125135
126- def _create_implement_type_class (self ):
127- registry = self
128-
129- class Implements (object ):
130- def __getattr__ (self , item ):
131- return self [item ]
132-
133- def __getitem__ (self , item ):
134- if isinstance (item , tuple ):
135- type_thunk = ThunkList ([ResolveThunk (registry ._resolve_type , i ) for i in item ])
136-
137- else :
138- type_thunk = ThunkList ([ResolveThunk (registry ._resolve_type , item )])
139-
140- return registry ._create_object_type_class (type_thunk )
141-
142- implements = Implements ()
143- return implements
144-
145136 def _create_interface_type_class (self ):
146137 registry = self
147138
@@ -160,18 +151,41 @@ class Interface(six.with_metaclass(RegistryInterfaceMeta)):
160151
161152 return Interface
162153
154+ def _create_union_type_class (self , types_thunk ):
155+ registry = self
156+
157+ class RegistryUnionMeta (UnionMeta ):
158+ @staticmethod
159+ def _register (union ):
160+ registry .register (union )
161+
162+ @staticmethod
163+ def _get_registry ():
164+ return registry
165+
166+ class Union (six .with_metaclass (RegistryUnionMeta )):
167+ abstract = True
168+
169+ @staticmethod
170+ def _get_types ():
171+ return TransformThunkList (types_thunk , get_named_type )
172+
173+ return Union
174+
175+ def _create_is_type_of (self , type ):
176+ return partial (self ._is_type_of , type )
177+
178+ def _is_type_of (self , type , obj , info ):
179+ return obj .__class__ in self ._registered_types_can_be [type ]
180+
163181 def _add_interface_declared_fields (self , interface , attrs ):
164182 self ._interface_declared_fields [interface ] = attrs
165183
166184 def _get_interface_declared_fields (self , interface ):
167185 return self ._interface_declared_fields .get (interface , {})
168186
169- def _add_impl_to_interfaces (self , * types ):
170- type_map = {}
171- for type in types :
172- type_map = type_map_reducer (type_map , type )
173-
174- for type in type_map :
187+ def _add_impl_to_interfaces (self ):
188+ for type in self ._registered_types .values ():
175189 if not isinstance (type , GraphQLObjectType ):
176190 continue
177191
@@ -185,10 +199,10 @@ def _add_impl_to_interfaces(self, *types):
185199
186200 interface ._impls .append (type )
187201
188- def schema (self , query , mutation = None ):
202+ def Schema (self , query , mutation = None ):
189203 query = self [query ]()
190204 mutation = self [mutation ]()
191- self ._add_impl_to_interfaces (query , mutation )
205+ self ._add_impl_to_interfaces ()
192206 return GraphQLSchema (query = query , mutation = mutation )
193207
194208 def type (self , name ):
0 commit comments