From b170a3261e318c5bc3d9efb2c240efe4dc617935 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 7 May 2026 14:47:38 +0100 Subject: [PATCH 1/4] Submesh: support tuple subdomain_id --- firedrake/cython/dmcommon.pyx | 45 +++++++++++-- firedrake/mesh.py | 66 +++++++++++++++++-- .../submesh/test_submesh_interface.py | 45 +++++++++++++ 3 files changed, 146 insertions(+), 10 deletions(-) create mode 100644 tests/firedrake/submesh/test_submesh_interface.py diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index a4d4d92460..4c2b3fdd60 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -7,6 +7,7 @@ import math import cython import numpy as np import firedrake +from collections.abc import Sequence from firedrake.petsc import PETSc from mpi4py import MPI from firedrake.utils import IntType, ScalarType @@ -2203,11 +2204,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray): def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]: """Return mesh periodicity information. - + This function returns a 2-tuple of bools per dimension where the first entry indicates whether the mesh is periodic in that dimension, and the second indicates whether the mesh is single-cell periodic in that dimension. - + """ cdef: const PetscReal *maxCell, *L @@ -3971,7 +3972,7 @@ def create_halo_exchange_sf(PETSc.DM dm): def submesh_create(PETSc.DM dm, PetscInt subdim, label_name, - PetscInt label_value, + subdomain_id, PetscBool ignore_label_halo, comm=None): """Create submesh. @@ -3984,8 +3985,8 @@ def submesh_create(PETSc.DM dm, Topological dimension of the submesh label_name : str Name of the label - label_value : int - Value in the label + subdomain_id : int | Sequence + Values in the label ignore_label_halo : bool If labeled points in the halo are ignored. comm : PETSc.Comm | None @@ -3995,10 +3996,42 @@ def submesh_create(PETSc.DM dm, cdef: PETSc.DMLabel label, temp_label char *temp_label_name = "firedrake_submesh_temp_label" - PetscInt pStart, pEnd, p, i, stratum_size + PetscInt pStart, pEnd, p, i, stratum_size, label_value PETSc.PetscIS stratum_is = NULL const PetscInt *stratum_indices = NULL + # Parse non-integer subdomain_id into a single value + if isinstance(subdomain_id, str): + if subdomain_id == "on_boundary": + label_name = "exterior_facets" + label_value = 1 + else: + raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") + elif isinstance(subdomain_id, Sequence): + label = dm.getLabel(label_name) + + def get_label_points(sub): + if sub == "on_boundary": + return dm.getStratumIS("exterior_facets", 1) + elif isinstance(sub, str): + raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") + else: + return label.getStratumIS(sub) + + # Take the union of the labels in the list + iset = PETSc.IS().createGeneral([], comm=dm.comm) + for sub in subdomain_id: + iset = iset.union(get_label_points(sub)) + + # Create a temporary label + label_name = "firedrake_composite_subdomain_label" + dm.createLabel(label_name) + label_value = 1 + dm.getLabelIdIS(label_name).getSize() + label = dm.getLabel(label_name) + label.setStratumIS(label_value, iset) + else: + label_value = subdomain_id + label = dm.getLabel(label_name) # Create temp_label that contains no lower-dimensional points. dm.createLabel(temp_label_name) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 771462dc25..dd9cc0fa2a 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -4881,10 +4881,11 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig subdim : int | None Topological dimension of the submesh. Defaults to ``mesh.topological_dimension``. - subdomain_id : int | None + subdomain_id : int | Sequence | None Subdomain ID representing the submesh. - If `None` the submesh will cover the entire domain. - This is useful to obtain a codim-1 submesh over all facets or + If multiple subdomain IDs are provided, their union is taken. + If `None` the submesh will cover the entire domain, + this is useful to obtain a codim-1 submesh over all facets or a submesh over a different communicator. label_name : str | None Name of the label to search ``subdomain_id`` in. @@ -4927,6 +4928,47 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig ridges to be contained in the quad mesh are shared by at most two facets to make the quad mesh orientation algorithm work. + Examples + -------- + >>> mesh = UnitSquareMesh(4, 4) + >>> x, y = SpatialCoordinate(mesh) + >>> DG = FunctionSpace(mesh, "DG", 0) + >>> DGT = FunctionSpace(mesh, "DGT", 0) + + Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain + + >>> cell_marker = assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)) + >>> mesh.mark_entities(cell_marker, 111) + >>> submesh = Submesh(mesh, subdomain_id=111) + + Mark a facet subdomain and construct a codim-1 submesh from all facets in the subdomain + + >>> facet_marker = assemble(interpolate(conditional(lt(abs(x-0.5), 1E-12), 1, 0), DGT)) + >>> mesh.mark_entities(facet_marker, 222) + >>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=222) + + Construct a codim-0 submesh of the union of multiple subdomains by passing a list + + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1) + >>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2) + >>> submesh = Submesh(mesh, subdomain_id=[1, 2]) + + Construct a codim-1 submesh of all the facets (the skeleton mesh) + + >>> submesh = Submesh(mesh, subdim=1) + + Construct a codim-1 submesh of the entire boundary + + >>> submesh = Submesh(mesh, subdomain_id="on_boundary") + + Construct a codim-1 submesh of the union of multiple boundaries + + >>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=[1, 2, 3]) + + Construct a codim-0 submesh of the part of the mesh owned by each MPI rank + + >>> submesh = Submesh(mesh, ignore_halo=True, comm=COMM_SELF) + """ if not isinstance(mesh, MeshGeometry): raise TypeError("Parent mesh must be a `MeshGeometry`") @@ -4934,6 +4976,18 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig raise NotImplementedError("Can not create a submesh of an ``ExtrudedMesh``") elif isinstance(mesh.topology, VertexOnlyMeshTopology): raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``") + + if subdomain_id == "on_boundary": + if subdim is None: + subdim = mesh.topological_dimension - 1 + elif subdim != mesh.topological_dimension - 1: + raise ValueError('subdomain_id="on_boundary" requires subdim=dim-1') + if label_name is None: + label_name = "exterior_facets" + elif label_name != "exterior_facets": + raise ValueError('subdomain_id="on_boundary" requires label_name="exterior_facets"') + subdomain_id = 1 + if subdim is None: subdim = mesh.topological_dimension plex = mesh.topology_dm @@ -4959,7 +5013,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig if subplex.getDimension() != subdim: raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})") if reorder is None: - # Ideally we should set perm_is = mesh.dm_reordering[label_indices] + # Ideally we should set perm_is = mesh._dm_renumbering[label_indices] reorder = mesh._did_reordering submesh = Mesh( @@ -4972,6 +5026,10 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig ) # Tag the relabeled mesh with the original distribution parameters submesh._distribution_parameters = mesh._distribution_parameters + # Store the construction parameters in case we need to reconstruct this Submesh + submesh._submesh_label_name = label_name + submesh._submesh_subdomain_id = subdomain_id + submesh._submesh_ignore_halo = ignore_halo return submesh diff --git a/tests/firedrake/submesh/test_submesh_interface.py b/tests/firedrake/submesh/test_submesh_interface.py new file mode 100644 index 0000000000..394738fab8 --- /dev/null +++ b/tests/firedrake/submesh/test_submesh_interface.py @@ -0,0 +1,45 @@ +import pytest +import numpy as np +from firedrake import * + + +def test_submesh_subdomain_id_union(): + mesh = UnitSquareMesh(4, 4) + x, y = SpatialCoordinate(mesh) + M = FunctionSpace(mesh, "DG", 0) + m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0)) + m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0)) + mesh.mark_entities(m1, 111) + mesh.mark_entities(m2, 222) + + subdomain_id = [111, 222] + submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id) + + m3 = Function(M).interpolate(m1 + m2 - m1 * m2) + expected = assemble(m3*dx) + assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12 + + mesh.mark_entities(m3, 333) + submesh2 = Submesh(mesh, mesh.topological_dimension, 333) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) + + +@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)]) +def test_submesh_facet_subdomain_id_union(subdomain_id): + mesh = UnitCubeMesh(2, 2, 2) + submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id) + if subdomain_id == "on_boundary": + area = assemble(1*ds(domain=mesh)) + else: + area = assemble(1*ds(subdomain_id, domain=mesh)) + assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12 + + DGT = FunctionSpace(mesh, "DGT", 0) + facet_function = Function(DGT) + DirichletBC(DGT, 1, subdomain_id).apply(facet_function) + facet_value = 999 + rmesh = RelabeledMesh(mesh, [facet_function], [facet_value]) + submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value) + assert submesh2.cell_set.size == submesh1.cell_set.size + assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data) From a2696d4db44d53796269e33c8d42298b4f4fa2cb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 8 May 2026 18:30:24 +0100 Subject: [PATCH 2/4] review suggestions --- firedrake/cython/dmcommon.pyx | 60 ++++++++++++----------------------- 1 file changed, 20 insertions(+), 40 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index 4c2b3fdd60..f8e8ac953f 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -7,13 +7,14 @@ import math import cython import numpy as np import firedrake -from collections.abc import Sequence from firedrake.petsc import PETSc from mpi4py import MPI from firedrake.utils import IntType, ScalarType from libc.string cimport memset from libc.stdlib cimport qsort from finat.element_factory import as_fiat_cell +from numbers import Integral +from collections.abc import Sequence cimport numpy as np cimport mpi4py.MPI as MPI @@ -3994,52 +3995,32 @@ def submesh_create(PETSc.DM dm, """ cdef: + PETSc.IS points, subpoints PETSc.DMLabel label, temp_label char *temp_label_name = "firedrake_submesh_temp_label" PetscInt pStart, pEnd, p, i, stratum_size, label_value - PETSc.PetscIS stratum_is = NULL const PetscInt *stratum_indices = NULL - - # Parse non-integer subdomain_id into a single value - if isinstance(subdomain_id, str): - if subdomain_id == "on_boundary": - label_name = "exterior_facets" - label_value = 1 - else: - raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.") - elif isinstance(subdomain_id, Sequence): - label = dm.getLabel(label_name) - - def get_label_points(sub): - if sub == "on_boundary": - return dm.getStratumIS("exterior_facets", 1) - elif isinstance(sub, str): - raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") - else: - return label.getStratumIS(sub) - - # Take the union of the labels in the list - iset = PETSc.IS().createGeneral([], comm=dm.comm) - for sub in subdomain_id: - iset = iset.union(get_label_points(sub)) - - # Create a temporary label - label_name = "firedrake_composite_subdomain_label" - dm.createLabel(label_name) - label_value = 1 + dm.getLabelIdIS(label_name).getSize() - label = dm.getLabel(label_name) - label.setStratumIS(label_value, iset) - else: - label_value = subdomain_id - + # Cast subdomain_id into an iterable + if isinstance(subdomain_id, str) or not isinstance(subdomain_id, Sequence): + subdomain_id = (subdomain_id,) + # Take the union of the all the label values label = dm.getLabel(label_name) + points = PETSc.IS().createGeneral([], comm=dm.comm) + for sub in subdomain_id: + if isinstance(sub, Integral): + subpoints = label.getStratumIS(sub) + elif sub == "on_boundary": + subpoints = dm.getStratumIS("exterior_facets", 1) + else: + raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") + points = points.union(subpoints) # Create temp_label that contains no lower-dimensional points. dm.createLabel(temp_label_name) temp_label = dm.getLabel(temp_label_name) - CHKERR(DMLabelGetStratumSize(label.dmlabel, label_value, &stratum_size)) + label_value = 1 + CHKERR(ISGetSize(points.iset, &stratum_size)) if stratum_size > 0: - CHKERR(DMLabelGetStratumIS(label.dmlabel, label_value, &stratum_is)) - CHKERR(ISGetIndices(stratum_is, &stratum_indices)) + CHKERR(ISGetIndices(points.iset, &stratum_indices)) CHKERR(DMPlexGetDepthStratum(dm.dm, subdim, &pStart, &pEnd)) for i in range(stratum_size): p = stratum_indices[i] @@ -4047,8 +4028,7 @@ def submesh_create(PETSc.DM dm, # culling all lower-dimensional points. if pStart <= p < pEnd: CHKERR(DMLabelSetValue(temp_label.dmlabel, p, label_value)) - CHKERR(ISRestoreIndices(stratum_is, &stratum_indices)) - CHKERR(ISDestroy(&stratum_is)) + CHKERR(ISRestoreIndices(points.iset, &stratum_indices)) # Make submesh using temp_label. subdm, ownership_transfer_sf = dm.filter(label=temp_label, value=label_value, From 55d625126619fef9d392896ee772183cfa7d2426 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 9 May 2026 10:46:28 +0100 Subject: [PATCH 3/4] tidy --- firedrake/mesh.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 947d359c0a..942eb1c527 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -5029,10 +5029,6 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig ) # Tag the relabeled mesh with the original distribution parameters submesh._distribution_parameters = mesh._distribution_parameters - # Store the construction parameters in case we need to reconstruct this Submesh - submesh._submesh_label_name = label_name - submesh._submesh_subdomain_id = subdomain_id - submesh._submesh_ignore_halo = ignore_halo return submesh From ddf9137848fd43f90b537fc21bc39184496f170e Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 9 May 2026 12:23:25 +0100 Subject: [PATCH 4/4] fixes --- firedrake/cython/dmcommon.pyx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index f8e8ac953f..785f91c51b 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -3998,14 +3998,15 @@ def submesh_create(PETSc.DM dm, PETSc.IS points, subpoints PETSc.DMLabel label, temp_label char *temp_label_name = "firedrake_submesh_temp_label" - PetscInt pStart, pEnd, p, i, stratum_size, label_value + PetscInt pStart, pEnd, p, i, stratum_size = 0, label_value = 1 const PetscInt *stratum_indices = NULL + # Cast subdomain_id into an iterable if isinstance(subdomain_id, str) or not isinstance(subdomain_id, Sequence): subdomain_id = (subdomain_id,) # Take the union of the all the label values label = dm.getLabel(label_name) - points = PETSc.IS().createGeneral([], comm=dm.comm) + points = PETSc.IS() for sub in subdomain_id: if isinstance(sub, Integral): subpoints = label.getStratumIS(sub) @@ -4013,12 +4014,15 @@ def submesh_create(PETSc.DM dm, subpoints = dm.getStratumIS("exterior_facets", 1) else: raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.") - points = points.union(subpoints) + if points: + points = points.union(subpoints) + else: + points = subpoints # Create temp_label that contains no lower-dimensional points. dm.createLabel(temp_label_name) temp_label = dm.getLabel(temp_label_name) - label_value = 1 - CHKERR(ISGetSize(points.iset, &stratum_size)) + if points: + CHKERR(ISGetSize(points.iset, &stratum_size)) if stratum_size > 0: CHKERR(ISGetIndices(points.iset, &stratum_indices)) CHKERR(DMPlexGetDepthStratum(dm.dm, subdim, &pStart, &pEnd))