Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions yt/frontends/amrvac/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from more_itertools import always_iterable

from yt.config import ytcfg
from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.index_subobjects.stretched_grid import StretchedGrid
from yt.data_objects.static_output import Dataset
from yt.funcs import mylog, setdefaultattr
from yt.geometry.api import Geometry
Expand Down Expand Up @@ -52,17 +52,18 @@ def _parse_geometry(geometry_tag: str) -> Geometry:
return Geometry(geometry_str.lower())


class AMRVACGrid(AMRGridPatch):
class AMRVACGrid(StretchedGrid):
"""A class to populate AMRVACHierarchy.grids, setting parent/children relations."""

_id_offset = 0

def __init__(self, id, index, level):
def __init__(self, id, cell_widths, filename, index, level, dims):
# <level> should use yt's convention (start from 0)
super().__init__(id, filename=index.index_filename, index=index)
super().__init__(id=id, filename=filename, index=index, cell_widths=cell_widths)
self.Parent = None
self.Children = []
self.Level = level
self.ActiveDimensions = dims

def get_global_startindex(self):
"""Refresh and retrieve the starting index for each dimension at current level.
Expand Down Expand Up @@ -142,6 +143,40 @@ def _parse_index(self):
dim = self.dataset.dimensionality

self.grids = np.empty(self.num_grids, dtype="object")
meshlist = self.ds.namelist["meshlist"]
if (stretch_dim := meshlist.get("stretch_dim")) is not None:
assert isinstance(stretch_dim, list)
assert len(stretch_dim) >= self.ds.dimensionality
stretch_baselevel = meshlist.get("qstretch_baselevel")
if "qstretch_baselevel" not in meshlist:
# compute default values dynamically, just as done in AMRVAC
stretched_dims = [bool(k) for k in stretch_dim]
assert sum(stretched_dims) == 1 # exactly one stretched direction
stretched_dim = stretched_dims.index(True)
_sbl = [
1.0,
] * self.ds.dimensionality
_sbl[stretched_dim] = (
meshlist[f"xprobmax{stretched_dim + 1}"]
/ meshlist[f"xprobmin{stretched_dim + 1}"]
) ** (1.0 / meshlist[f"domain_nx{stretched_dim + 1}"])
stretch_baselevel = tuple(_sbl)
elif isinstance(stretch_baselevel := meshlist["qstretch_baselevel"], list):
assert len(stretch_baselevel) >= self.ds.dimensionality
stretch_baselevel = (
float(b) for b in stretch_baselevel[: self.ds.dimensionality]
)
else:
assert isinstance(stretch_baselevel, float | int)
stretched_dims = [bool(k) for k in stretch_dim]
assert sum(stretched_dims) == 1 # exactly one stretched direction
stretched_dim = stretched_dims.index(True)
_sbl = [
1.0,
] * self.ds.dimensionality
_sbl[stretched_dim] = stretch_baselevel
stretch_baselevel = tuple(_sbl)

for igrid, (ytlevel, morton_index) in enumerate(
zip(ytlevels, morton_indices, strict=True)
):
Expand All @@ -152,7 +187,14 @@ def _parse_index(self):
self.grid_left_edge[igrid, :dim] = left_edge
self.grid_right_edge[igrid, :dim] = left_edge + block_nx * dx
self.grid_dimensions[igrid, :dim] = block_nx
self.grids[igrid] = self.grid(igrid, self, ytlevels[igrid])
self.grids[igrid] = self.grid(
id=igrid,
index=self,
level=ytlevels[igrid],
filename=self.index_filename,
cell_widths=_cell_widths,
dims=self.grid_dimensions[igrid],
)

def _populate_grid_objects(self):
# required method
Expand Down
Loading