From 643ac2408d5276d930ef81337743cdf919dcec89 Mon Sep 17 00:00:00 2001
From: Anatole Storck
Date: Wed, 3 Dec 2025 19:10:15 +0100
Subject: [PATCH] [feat] add parallelization for smooth covering grids

---
 .../construction_data_containers.py           | 41 +++++++++++++++----
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py
index 95fce4002a2..c22a78d6c48 100644
--- a/yt/data_objects/construction_data_containers.py
+++ b/yt/data_objects/construction_data_containers.py
@@ -1118,7 +1118,8 @@ def _fill_fields(self, fields):
         if not is_sequence(self.ds.refine_by):
             refine_by = [refine_by, refine_by, refine_by]
         refine_by = np.array(refine_by, dtype="i8")
-        for chunk in parallel_objects(self._data_source.chunks(fields, "io")):
+        for chunk in self._data_source.piter():
+            chunk.get_data(fields)
             input_fields = [chunk[field] for field in fields]
             # NOTE: This usage of "refine_by" is actually *okay*, because it's
             # being used with respect to iref, which is *already* scaled!
@@ -1478,12 +1479,13 @@ def _compute_minimum_level(self):
             ils.data_source.max_level = l
             ils.data_source.loose_selection = False
             min_level = self.level
-            for chunk in ils.data_source.chunks([], "io"):
+            for chunk in ils.data_source.piter():
                 # With our odd selection methods, we can sometimes get no-sized ires.
                 ir = chunk.ires
                 if ir.size == 0:
                     continue
                 min_level = min(ir.min(), min_level)
+            min_level = self.comm.mpi_allreduce(min_level, op="min")
             if min_level >= l:
                 break
         self._min_level = min_level
@@ -1507,19 +1509,29 @@ def _fill_fields(self, fields):
             if level < min_level:
                 self._update_level_state(ls)
                 continue
+
+            mylog.debug("Filling level %d", level)
+
             nd = self.ds.dimensionality
             refinement = np.zeros_like(ls.base_dx)
             refinement += self.ds.relative_refinement(0, ls.current_level)
             refinement[nd:] = 1
             domain_dims = self.ds.domain_dimensions * refinement
             domain_dims = domain_dims.astype("int64")
-            tot = ls.current_dims.prod()
-            for chunk in ls.data_source.chunks(fields, "io"):
-                chunk[fields[0]]
+
+            initial_tot = ls.current_dims.prod()
+            filled_cells = 0
+
+            output_fields = [
+                np.zeros_like(ls.fields[0], dtype="float64") for field in fields
+            ]
+
+            for chunk in ls.data_source.piter():
+                chunk.get_data(fields)
                 input_fields = [chunk[field] for field in fields]
-                tot -= fill_region(
+                filled_cells += fill_region(
                     input_fields,
-                    ls.fields,
+                    output_fields,
                     ls.current_level,
                     ls.global_startindex,
                     chunk.icoords,
@@ -1527,7 +1539,18 @@ def _fill_fields(self, fields):
                     domain_dims,
                     refine_by,
                 )
-            if level == 0 and tot != 0:
+
+            filled_cells = self.comm.mpi_allreduce(filled_cells, op="sum")
+            for i in range(len(fields)):
+                output_fields[i] = self.comm.mpi_allreduce(output_fields[i], op="sum")
+
+            leftover_cells = initial_tot - filled_cells
+
+            for i in range(len(fields)):
+                replace_mask = output_fields[i] != 0
+                ls.fields[i][replace_mask] = output_fields[i][replace_mask]
+
+            if level == 0 and leftover_cells != 0:
                 runtime_errors_count += 1
             self._update_level_state(ls)
         if runtime_errors_count:
@@ -1570,7 +1593,7 @@ def _initialize_level_state(self, fields):
         ls.current_dims = idims.astype("int32")
         ls.left_edge = ls.global_startindex * ls.current_dx + self.ds.domain_left_edge.d
         ls.right_edge = ls.left_edge + ls.current_dims * ls.current_dx
-        ls.fields = [np.zeros(idims, dtype="float64") - 999 for field in fields]
+        ls.fields = [np.zeros(idims, dtype="float64") for field in fields]
         self._setup_data_source(ls)
         return ls
 