@@ -68,6 +68,31 @@ def get_existing_enum_types():
6868 return enum_types
6969
7070
71+ def get_enum_categories ():
72+ """
73+ Get a mapping of enum types to their categories.
74+
75+ Returns:
76+ A dictionary mapping enum type names to their categories
77+ """
78+ enum_categories = {}
79+
80+ if not ENUMS_DIR .exists ():
81+ return enum_categories
82+
83+ for root , _ , files in os .walk (ENUMS_DIR ):
84+ category = os .path .basename (root )
85+ if category == "enums" : # Skip the root enums directory
86+ continue
87+
88+ for file in files :
89+ if file .endswith ('.py' ) and not file .startswith ('__' ):
90+ enum_name = os .path .splitext (file )[0 ]
91+ enum_categories [enum_name ] = category
92+
93+ return enum_categories
94+
95+
7196def modify_imports_for_enums (files , enum_types ):
7297 """
7398 Modify import statements in generated files to use enum types from the enums package.
@@ -76,17 +101,8 @@ def modify_imports_for_enums(files, enum_types):
76101 files: Dictionary mapping file paths to generated code
77102 enum_types: Set of type names that have enum implementations
78103 """
79- enum_locations = {}
80-
81- # First, find all enum implementations and their locations
82- for root , _ , files_in_dir in os .walk (ENUMS_DIR ):
83- for file in files_in_dir :
84- if file .endswith ('.py' ) and not file .startswith ('__' ):
85- enum_name = os .path .splitext (file )[0 ]
86- # Get the category from the directory structure
87- rel_path = os .path .relpath (root , ENUMS_DIR )
88- category = rel_path if rel_path != '.' else ''
89- enum_locations [enum_name ] = category
104+ # Get enum categories
105+ enum_categories = get_enum_categories ()
90106
91107 # Now update imports in all files
92108 for file_path , content in files .items ():
@@ -97,20 +113,21 @@ def modify_imports_for_enums(files, enum_types):
97113
98114 # Look for import statements for enum types
99115 for enum_type in enum_types :
100- if enum_type not in enum_locations :
116+ if enum_type not in enum_categories :
101117 continue
102118
103- category = enum_locations [enum_type ]
104- category_path = f".{ category } " if category else ""
119+ category = enum_categories [enum_type ]
105120
106121 # Patterns for imports from models (handle various potential patterns)
107122 old_import_patterns = [
108123 f"from msgspec_schemaorg.models.intangible.{ enum_type } import { enum_type } " ,
109- f"from msgspec_schemaorg.models{ category_path } .{ enum_type } import { enum_type } "
124+ f"from msgspec_schemaorg.models.{ category } .{ enum_type } import { enum_type } " ,
125+ f"from .{ enum_type } import { enum_type } " ,
126+ f"from ..{ category } .{ enum_type } import { enum_type } "
110127 ]
111128
112129 # Pattern for corrected import from enums
113- new_import = f"from msgspec_schemaorg.enums{ category_path } .{ enum_type } import { enum_type } "
130+ new_import = f"from msgspec_schemaorg.enums. { category } .{ enum_type } import { enum_type } "
114131
115132 # Replace the import statement
116133 for old_pattern in old_import_patterns :
@@ -135,73 +152,28 @@ def modify_init_files(files, enum_types):
135152 continue
136153
137154 modified_content = content
155+ modified = False
138156
139157 # Remove imports for enum types
140158 for enum_type in enum_types :
141159 import_line = f"from .{ enum_type } import { enum_type } \n "
142160 if import_line in modified_content :
143161 modified_content = modified_content .replace (import_line , "" )
162+ modified = True
144163 print (f"Removed import for { enum_type } in { file_path } " )
145164
146165 # Update __all__ list if present
147166 if "__all__ = [" in modified_content :
148167 # Find __all__ list
149168 all_list_start = modified_content .find ("__all__ = [" )
150169 all_list_end = modified_content .find ("]" , all_list_start )
151- all_list = modified_content [all_list_start :all_list_end + 1 ]
152-
153- # Remove enum types from __all__ list
154- for enum_type in enum_types :
155- # Look for different patterns in __all__ list
156- patterns = [
157- f"'{ enum_type } '," ,
158- f"'{ enum_type } '" ,
159- f"\" { enum_type } \" ," ,
160- f"\" { enum_type } \" "
161- ]
162-
163- for pattern in patterns :
164- if pattern in all_list :
165- # Replace with empty string or just a space depending on location
166- if pattern .endswith ("," ):
167- all_list = all_list .replace (pattern , "" )
168- else :
169- all_list = all_list .replace (pattern , "" )
170- print (f"Removed { enum_type } from __all__ list in { file_path } " )
171-
172- # Clean up any empty commas or double commas
173- all_list = all_list .replace (",," , "," )
174- all_list = all_list .replace (", ," , "," )
175- all_list = all_list .replace ("[," , "[" )
176- all_list = all_list .replace (",]" , "]" )
177-
178- # Replace the old __all__ list with the cleaned up one
179- modified_content = modified_content .replace (
180- modified_content [all_list_start :all_list_end + 1 ],
181- all_list
182- )
183-
184- # Update the content in the files dictionary
185- files [file_path ] = modified_content
186-
187- # Also update the root models/__init__.py if it exists
188- root_init_path = Path (DEFAULT_OUTPUT_DIR ) / "__init__.py"
189- if root_init_path .exists ():
190- try :
191- with open (root_init_path , 'r' ) as f :
192- root_init_content = f .read ()
193-
194- modified_root_init = root_init_content
195-
196- # Update __all__ list if present
197- if "__all__ = [" in modified_root_init :
198- # Find __all__ list
199- all_list_start = modified_root_init .find ("__all__ = [" )
200- all_list_end = modified_root_init .find ("]" , all_list_start )
201- all_list = modified_root_init [all_list_start :all_list_end + 1 ]
170+ if all_list_end > all_list_start : # Ensure we found the end bracket
171+ all_list = modified_content [all_list_start :all_list_end + 1 ]
172+ original_all_list = all_list
202173
203174 # Remove enum types from __all__ list
204175 for enum_type in enum_types :
176+ # Look for different patterns in __all__ list
205177 patterns = [
206178 f"'{ enum_type } '," ,
207179 f"'{ enum_type } '" ,
@@ -213,7 +185,8 @@ def modify_init_files(files, enum_types):
213185 if pattern in all_list :
214186 # Replace with empty string
215187 all_list = all_list .replace (pattern , "" )
216- print (f"Removed { enum_type } from __all__ list in root __init__.py" )
188+ modified = True
189+ print (f"Removed { enum_type } from __all__ list in { file_path } " )
217190
218191 # Clean up any empty commas or double commas
219192 all_list = all_list .replace (",," , "," )
@@ -222,14 +195,68 @@ def modify_init_files(files, enum_types):
222195 all_list = all_list .replace (",]" , "]" )
223196
224197 # Replace the old __all__ list with the cleaned up one
225- modified_root_init = modified_root_init .replace (
226- modified_root_init [all_list_start :all_list_end + 1 ],
227- all_list
228- )
198+ if all_list != original_all_list :
199+ modified_content = modified_content .replace (
200+ original_all_list ,
201+ all_list
202+ )
203+
204+ # Update the content in the files dictionary if modified
205+ if modified :
206+ files [file_path ] = modified_content
207+
208+ # Also update the root models/__init__.py if it exists
209+ root_init_path = Path (DEFAULT_OUTPUT_DIR ) / "__init__.py"
210+ if root_init_path .exists ():
211+ try :
212+ with open (root_init_path , 'r' ) as f :
213+ root_init_content = f .read ()
214+
215+ modified_root_init = root_init_content
216+ modified = False
217+
218+ # Update __all__ list if present
219+ if "__all__ = [" in modified_root_init :
220+ # Find __all__ list
221+ all_list_start = modified_root_init .find ("__all__ = [" )
222+ all_list_end = modified_root_init .find ("]" , all_list_start )
223+ if all_list_end > all_list_start : # Ensure we found the end bracket
224+ all_list = modified_root_init [all_list_start :all_list_end + 1 ]
225+ original_all_list = all_list
226+
227+ # Remove enum types from __all__ list
228+ for enum_type in enum_types :
229+ patterns = [
230+ f"'{ enum_type } '," ,
231+ f"'{ enum_type } '" ,
232+ f"\" { enum_type } \" ," ,
233+ f"\" { enum_type } \" "
234+ ]
235+
236+ for pattern in patterns :
237+ if pattern in all_list :
238+ # Replace with empty string
239+ all_list = all_list .replace (pattern , "" )
240+ modified = True
241+ print (f"Removed { enum_type } from __all__ list in root __init__.py" )
242+
243+ # Clean up any empty commas or double commas
244+ all_list = all_list .replace (",," , "," )
245+ all_list = all_list .replace (", ," , "," )
246+ all_list = all_list .replace ("[," , "[" )
247+ all_list = all_list .replace (",]" , "]" )
248+
249+ # Replace the old __all__ list with the cleaned up one
250+ if all_list != original_all_list :
251+ modified_root_init = modified_root_init .replace (
252+ original_all_list ,
253+ all_list
254+ )
229255
230- # Write the modified content back to the file
231- with open (root_init_path , 'w' ) as f :
232- f .write (modified_root_init )
256+ # Write the modified content back to the file if changed
257+ if modified :
258+ with open (root_init_path , 'w' ) as f :
259+ f .write (modified_root_init )
233260 except Exception as e :
234261 print (f"Error updating root __init__.py: { e } " )
235262
0 commit comments