Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import math
import cython
import numpy as np
import firedrake
from collections.abc import Sequence
from firedrake.petsc import PETSc
from mpi4py import MPI
from firedrake.utils import IntType, ScalarType
Expand Down Expand Up @@ -2203,11 +2204,11 @@ def _get_expanded_dm_dg_coords(dm: PETSc.DM, ndofs: np.ndarray):

def _get_periodicity(dm: PETSc.DM) -> tuple[tuple[bool, bool], ...]:
"""Return mesh periodicity information.

This function returns a 2-tuple of bools per dimension where the first entry indicates
whether the mesh is periodic in that dimension, and the second indicates whether the
mesh is single-cell periodic in that dimension.

"""
cdef:
const PetscReal *maxCell, *L
Expand Down Expand Up @@ -3971,7 +3972,7 @@ def create_halo_exchange_sf(PETSc.DM dm):
def submesh_create(PETSc.DM dm,
PetscInt subdim,
label_name,
PetscInt label_value,
subdomain_id,
PetscBool ignore_label_halo,
comm=None):
"""Create submesh.
Expand All @@ -3984,8 +3985,8 @@ def submesh_create(PETSc.DM dm,
Topological dimension of the submesh
label_name : str
Name of the label
label_value : int
Value in the label
subdomain_id : int | Sequence
Values in the label
ignore_label_halo : bool
If labeled points in the halo are ignored.
comm : PETSc.Comm | None
Expand All @@ -3995,14 +3996,42 @@ def submesh_create(PETSc.DM dm,
cdef:
PETSc.DMLabel label, temp_label
char *temp_label_name = <char *>"firedrake_submesh_temp_label"
PetscInt pStart, pEnd, p, i, stratum_size
PetscInt pStart, pEnd, p, i, stratum_size, label_value, temp_label_value
PETSc.PetscIS stratum_is = NULL
const PetscInt *stratum_indices = NULL

# Parse string subdomain_id into int value
if isinstance(subdomain_id, str):
if subdomain_id == "on_boundary":
label_name = "exterior_facets"
subdomain_id = 1
else:
raise ValueError(f"Submesh construction got invalid subdomain_id {subdomain_id}.")
label = dm.getLabel(label_name)
# Create temp_label that contains no lower-dimensional points.
# Create a temporary label
dm.createLabel(temp_label_name)
temp_label = dm.getLabel(temp_label_name)
# Parse tuple subdomain_id into a single value
if isinstance(subdomain_id, Sequence):
# Take the union of the labels in the list
iset = PETSc.IS().createGeneral([], comm=dm.comm)
for sub in subdomain_id:
if isinstance(sub, str):
if sub == "on_boundary":
cur = dm.getStratumIS("exterior_facets", 1)
else:
raise ValueError(f"Submesh construction got invalid subdomain_id {sub}.")
else:
cur = label.getStratumIS(sub)
iset = iset.union(cur)
# Add marker in temp_label with the union
label_value = 1
temp_label.setStratumIS(label_value, iset)
# Pass on temp_label
label = temp_label
else:
label_value = subdomain_id
# Add a marker in temp_label that contains no lower-dimensional points.
temp_label_value = label_value + 1
CHKERR(DMLabelGetStratumSize(<DMLabel>label.dmlabel, label_value, &stratum_size))
if stratum_size > 0:
CHKERR(DMLabelGetStratumIS(<DMLabel>label.dmlabel, label_value, &stratum_is))
Expand All @@ -4013,12 +4042,12 @@ def submesh_create(PETSc.DM dm,
# Only include points on the submesh topological dimension,
# culling all lower-dimensional points.
if pStart <= p < pEnd:
CHKERR(DMLabelSetValue(<DMLabel>temp_label.dmlabel, p, label_value))
CHKERR(DMLabelSetValue(<DMLabel>temp_label.dmlabel, p, temp_label_value))
CHKERR(ISRestoreIndices(stratum_is, &stratum_indices))
CHKERR(ISDestroy(&stratum_is))
# Make submesh using temp_label.
subdm, ownership_transfer_sf = dm.filter(label=temp_label,
value=label_value,
value=temp_label_value,
ignoreHalo=ignore_label_halo,
sanitizeSubMesh=PETSC_TRUE,
comm=comm)
Expand Down
66 changes: 62 additions & 4 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4881,10 +4881,11 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
subdim : int | None
Topological dimension of the submesh.
Defaults to ``mesh.topological_dimension``.
subdomain_id : int | None
subdomain_id : int | Sequence | None
Subdomain ID representing the submesh.
If `None` the submesh will cover the entire domain.
This is useful to obtain a codim-1 submesh over all facets or
If multiple subdomain IDs are provided, their union is taken.
If `None` the submesh will cover the entire domain,
this is useful to obtain a codim-1 submesh over all facets or
a submesh over a different communicator.
label_name : str | None
Name of the label to search ``subdomain_id`` in.
Expand Down Expand Up @@ -4927,13 +4928,66 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
ridges to be contained in the quad mesh are shared by at most two
facets to make the quad mesh orientation algorithm work.

Examples
--------
>>> mesh = UnitSquareMesh(4, 4)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we should be testing code snippets so they don't bitrot. That's a bit complicated here. Can you open an issue and assign me?

>>> x, y = SpatialCoordinate(mesh)
>>> DG = FunctionSpace(mesh, "DG", 0)
>>> DGT = FunctionSpace(mesh, "DGT", 0)

Mark a cell subdomain and construct a codim-0 submesh from all cells in the subdomain

>>> cell_marker = assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG))
>>> mesh.mark_entities(cell_marker, 111)
>>> submesh = Submesh(mesh, subdomain_id=111)

Mark a facet subdomain and construct a codim-1 submesh from all facets in the subdomain

>>> facet_marker = assemble(interpolate(conditional(lt(abs(x-0.5), 1E-12), 1, 0), DGT))
>>> mesh.mark_entities(facet_marker, 222)
>>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=222)

Construct a codim-0 submesh of the union of multiple subdomains by passing a list

>>> mesh.mark_entities(assemble(interpolate(conditional(lt(x, 0.5), 1, 0), DG)), 1)
>>> mesh.mark_entities(assemble(interpolate(conditional(lt(y, 0.5), 1, 0), DG)), 2)
>>> submesh = Submesh(mesh, subdomain_id=[1, 2])

Construct a codim-1 submesh of all the facets (the skeleton mesh)

>>> submesh = Submesh(mesh, subdim=1)

Construct a codim-1 submesh of the entire boundary

>>> submesh = Submesh(mesh, subdomain_id="on_boundary")

Construct a codim-1 submesh of the union of multiple boundaries

>>> submesh = Submesh(mesh, subdim=mesh.topological_dimension-1, subdomain_id=[1, 2, 3])

Construct a codim-0 submesh of the part of the mesh owned by each MPI rank

>>> submesh = Submesh(mesh, ignore_halo=True, comm=COMM_SELF)

"""
if not isinstance(mesh, MeshGeometry):
raise TypeError("Parent mesh must be a `MeshGeometry`")
if isinstance(mesh.topology, ExtrudedMeshTopology):
raise NotImplementedError("Can not create a submesh of an ``ExtrudedMesh``")
elif isinstance(mesh.topology, VertexOnlyMeshTopology):
raise NotImplementedError("Can not create a submesh of a ``VertexOnlyMesh``")

if subdomain_id == "on_boundary":
if subdim is None:
subdim = mesh.topological_dimension - 1
elif subdim != mesh.topological_dimension - 1:
raise ValueError('subdomain_id="on_boundary" requires subdim=dim-1')
if label_name is None:
label_name = "exterior_facets"
elif label_name != "exterior_facets":
raise ValueError('subdomain_id="on_boundary" requires label_name="exterior_facets"')
subdomain_id = 1

if subdim is None:
subdim = mesh.topological_dimension
plex = mesh.topology_dm
Expand All @@ -4959,7 +5013,7 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
if subplex.getDimension() != subdim:
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
if reorder is None:
# Ideally we should set perm_is = mesh.dm_reordering[label_indices]
# Ideally we should set perm_is = mesh._dm_renumbering[label_indices]
reorder = mesh._did_reordering

submesh = Mesh(
Expand All @@ -4972,6 +5026,10 @@ def Submesh(mesh, subdim=None, subdomain_id=None, label_name=None, name=None, ig
)
# Tag the relabeled mesh with the original distribution parameters
submesh._distribution_parameters = mesh._distribution_parameters
# Store the construction parameters in case we need to reconstruct this Submesh
submesh._submesh_label_name = label_name
submesh._submesh_subdomain_id = subdomain_id
submesh._submesh_ignore_halo = ignore_halo
return submesh


Expand Down
45 changes: 45 additions & 0 deletions tests/firedrake/submesh/test_submesh_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import numpy as np
from firedrake import *


def test_submesh_subdomain_id_union():
mesh = UnitSquareMesh(4, 4)
x, y = SpatialCoordinate(mesh)
M = FunctionSpace(mesh, "DG", 0)
m1 = Function(M).interpolate(conditional(lt(x, 0.5), 1, 0))
m2 = Function(M).interpolate(conditional(lt(y, 0.5), 1, 0))
mesh.mark_entities(m1, 111)
mesh.mark_entities(m2, 222)

subdomain_id = [111, 222]
submesh1 = Submesh(mesh, mesh.topological_dimension, subdomain_id=subdomain_id)

m3 = Function(M).interpolate(m1 + m2 - m1 * m2)
expected = assemble(m3*dx)
assert abs(assemble(1*dx(domain=submesh1)) - expected) < 1E-12

mesh.mark_entities(m3, 333)
submesh2 = Submesh(mesh, mesh.topological_dimension, 333)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)


@pytest.mark.parametrize("subdomain_id", ["on_boundary", (1, 3, 6)])
def test_submesh_facet_subdomain_id_union(subdomain_id):
mesh = UnitCubeMesh(2, 2, 2)
submesh1 = Submesh(mesh, mesh.topological_dimension - 1, subdomain_id=subdomain_id)
if subdomain_id == "on_boundary":
area = assemble(1*ds(domain=mesh))
else:
area = assemble(1*ds(subdomain_id, domain=mesh))
assert abs(assemble(1*dx(domain=submesh1)) - area) < 1E-12

DGT = FunctionSpace(mesh, "DGT", 0)
facet_function = Function(DGT)
DirichletBC(DGT, 1, subdomain_id).apply(facet_function)
facet_value = 999
rmesh = RelabeledMesh(mesh, [facet_function], [facet_value])
submesh2 = Submesh(rmesh, mesh.topological_dimension - 1, facet_value)
assert submesh2.cell_set.size == submesh1.cell_set.size
assert np.allclose(submesh2.coordinates.dat.data, submesh1.coordinates.dat.data)
Loading