Skip to content

Commit 9f26aab

Browse files
committed
Merge branch 'main' into pbrubeck/fix/integral_type_map
2 parents 0e20dbc + 56ddeed commit 9f26aab

File tree

12 files changed

+94
-63
lines changed

12 files changed

+94
-63
lines changed

firedrake/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def init_petsc():
6767
from firedrake.cofunction import Cofunction, RieszMap # noqa: F401
6868
from firedrake.constant import Constant # noqa: F401
6969
from firedrake.deflation import DeflatedSNES, Deflation # noqa: F401
70-
from firedrake.exceptions import ConvergenceError # noqa: F401
70+
from firedrake.exceptions import ConvergenceError, MismatchingDomainError # noqa: F401
7171
from firedrake.function import ( # noqa: F401
7272
Function, PointNotInDomainError,
7373
CoordinatelessFunction, PointEvaluator

firedrake/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from tsfc.exceptions import MismatchingDomainError # noqa: F401
12

23

34
class ConvergenceError(Exception):

firedrake/interpolation.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from tsfc.driver import compile_expression_dual_evaluation
3535
from tsfc.ufl_utils import extract_firedrake_constants, hash_expr
3636

37-
from firedrake.utils import IntType, ScalarType, known_pyop2_safe, tuplify
37+
from firedrake.utils import IntType, ScalarType, cached_property, known_pyop2_safe, tuplify
3838
from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir
3939
from firedrake.ufl_expr import Argument, Coargument, action
4040
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh
@@ -155,6 +155,57 @@ def options(self) -> InterpolateOptions:
155155
"""
156156
return self._options
157157

158+
@cached_property
159+
def _interpolator(self):
160+
"""Access the numerical interpolator.
161+
162+
Returns
163+
-------
164+
Interpolator
165+
An appropriate :class:`Interpolator` subclass for this
166+
interpolation expression.
167+
"""
168+
arguments = self.arguments()
169+
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
170+
if len(arguments) == 2 and has_mixed_arguments:
171+
return MixedInterpolator(self)
172+
173+
operand, = self.ufl_operands
174+
target_mesh = self.target_space.mesh()
175+
176+
try:
177+
source_mesh = extract_unique_domain(operand) or target_mesh
178+
except ValueError:
179+
raise NotImplementedError(
180+
"Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
181+
)
182+
183+
try:
184+
target_mesh = target_mesh.unique()
185+
source_mesh = source_mesh.unique()
186+
except RuntimeError:
187+
return MixedInterpolator(self)
188+
189+
submesh_interp_implemented = (
190+
all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh])
191+
and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
192+
and target_mesh.topological_dimension == source_mesh.topological_dimension
193+
)
194+
if target_mesh is source_mesh or submesh_interp_implemented:
195+
return SameMeshInterpolator(self)
196+
197+
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
198+
if isinstance(source_mesh.topology, VertexOnlyMeshTopology):
199+
return VomOntoVomInterpolator(self)
200+
if target_mesh.geometric_dimension != source_mesh.geometric_dimension:
201+
raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.")
202+
return SameMeshInterpolator(self)
203+
204+
if has_mixed_arguments or len(self.target_space) > 1:
205+
return MixedInterpolator(self)
206+
207+
return CrossMeshInterpolator(self)
208+
158209

159210
@PETSc.Log.EventDecorator()
160211
def interpolate(expr: Expr, V: WithGeometry | BaseForm, **kwargs) -> Interpolate:
@@ -353,46 +404,7 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
353404
An appropriate :class:`Interpolator` subclass for the given
354405
interpolation expression.
355406
"""
356-
arguments = expr.arguments()
357-
has_mixed_arguments = any(len(arg.function_space()) > 1 for arg in arguments)
358-
if len(arguments) == 2 and has_mixed_arguments:
359-
return MixedInterpolator(expr)
360-
361-
operand, = expr.ufl_operands
362-
target_mesh = expr.target_space.mesh()
363-
364-
try:
365-
source_mesh = extract_unique_domain(operand) or target_mesh
366-
except ValueError:
367-
raise NotImplementedError(
368-
"Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
369-
)
370-
371-
try:
372-
target_mesh = target_mesh.unique()
373-
source_mesh = source_mesh.unique()
374-
except RuntimeError:
375-
return MixedInterpolator(expr)
376-
377-
submesh_interp_implemented = (
378-
all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh])
379-
and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
380-
and target_mesh.topological_dimension == source_mesh.topological_dimension
381-
)
382-
if target_mesh is source_mesh or submesh_interp_implemented:
383-
return SameMeshInterpolator(expr)
384-
385-
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
386-
if isinstance(source_mesh.topology, VertexOnlyMeshTopology):
387-
return VomOntoVomInterpolator(expr)
388-
if target_mesh.geometric_dimension != source_mesh.geometric_dimension:
389-
raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.")
390-
return SameMeshInterpolator(expr)
391-
392-
if has_mixed_arguments or len(expr.target_space) > 1:
393-
return MixedInterpolator(expr)
394-
395-
return CrossMeshInterpolator(expr)
407+
return expr._interpolator
396408

397409

398410
class DofNotDefinedError(Exception):

firedrake/mesh.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,12 +3123,11 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
31233123
pyop2_index.extend(cell_node_map.values[ngidx])
31243124

31253125
# Find the correct coordinate permutation for each cell
3126-
# NB: Coordinates must be cast to real when running Firedrake in complex mode
31273126
permutation = find_permutation(
31283127
physical_space_points,
3129-
new_coordinates.dat.data[pyop2_index].reshape(
3128+
new_coordinates.dat.data[pyop2_index].real.reshape(
31303129
physical_space_points.shape
3131-
).astype(np.float64, copy=False),
3130+
),
31323131
tol=permutation_tol
31333132
)
31343133

tests/firedrake/regression/test_interpolate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,12 @@ def test_interpolator_reuse(family, degree, mode):
592592
u = Function(V.dual())
593593
expr = interpolate(TestFunction(V), u)
594594

595-
I = get_interpolator(expr)
595+
Iorig = get_interpolator(expr)
596596

597597
for k in range(3):
598+
I = get_interpolator(expr)
599+
assert I is Iorig
600+
598601
u.assign(rg.uniform(u.function_space()))
599602
expected = u.dat.data.copy()
600603

tests/firedrake/regression/test_multiple_domains.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def test_mismatching_meshes_indexed_function(mesh1, mesh3):
4141
with pytest.raises(NotImplementedError):
4242
project(d1, target)
4343

44-
with pytest.raises(ValueError):
44+
with pytest.raises(MismatchingDomainError):
4545
assemble(inner(d1, TestFunction(V2))*dx(domain=mesh3))
4646

47-
with pytest.raises(ValueError):
47+
with pytest.raises(MismatchingDomainError):
4848
assemble(inner(d1, TestFunction(V2))*dx(domain=mesh1))
4949

5050

@@ -177,29 +177,29 @@ def test_multi_domain_assemble():
177177

178178
for i, j in [(0, 1), (1, 0)]:
179179
a1 = inner(u[i], v[j])*dx(domain=mesh1)
180-
with pytest.raises(ValueError):
180+
with pytest.raises(MismatchingDomainError):
181181
assemble(a1)
182182
a2 = inner(u[i], v[j])*dx(domain=mesh2)
183-
with pytest.raises(ValueError):
183+
with pytest.raises(MismatchingDomainError):
184184
assemble(a2)
185185
l1 = inner(f[i], v[j])*dx(domain=mesh1)
186-
with pytest.raises(ValueError):
186+
with pytest.raises(MismatchingDomainError):
187187
assemble(l1)
188188
l2 = inner(f[i], v[j])*dx(domain=mesh2)
189-
with pytest.raises(ValueError):
189+
with pytest.raises(MismatchingDomainError):
190190
assemble(l2)
191191

192192
for i, j in [(0, 0), (1, 1)]:
193193
a = inner(u[i], v[j])*dx(domain=mesh1)
194194
if i == 1:
195-
with pytest.raises(ValueError):
195+
with pytest.raises(MismatchingDomainError):
196196
assemble(a)
197197
continue
198198
A = assemble(a)
199199
assert A.M.values.shape == (V.dim(), V.dim())
200200

201201
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[0], v[1])*dx(domain=mesh2)
202-
with pytest.raises(ValueError):
202+
with pytest.raises(MismatchingDomainError):
203203
assemble(a)
204204

205205
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[1], v[1])*dx(domain=mesh2)

tests/firedrake/regression/test_projection_zany.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def run_convergence_test(mh, el, degree, convrate):
8080
('Hermite', 3, 3.8),
8181
('Bell', 5, 4.7),
8282
('Argyris', 5, 5.8),
83-
('Argyris', 6, 6.7)])
83+
('Argyris', 6, 6.7),
84+
('Nonconforming Robust Wu-Xu', 7, 3.8)])
8485
def test_projection_zany_convergence_2d(hierarchy_2d, el, deg, convrate):
8586
run_convergence_test(hierarchy_2d[2:], el, deg, convrate)
8687

@@ -97,7 +98,9 @@ def test_projection_zany_convergence_3d(hierarchy_3d, el, deg, convrate):
9798
('HCT', 3),
9899
('HCT', 4),
99100
('Argyris', 5),
100-
('Argyris', 6)])
101+
('Argyris', 6),
102+
('Nonconforming Wu-Xu', 4),
103+
('Alfeld C2', 5)])
101104
def test_mass_conditioning(element, degree, hierarchy_2d):
102105
mass_cond = []
103106
for msh in hierarchy_2d[1:4]:

tsfc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from tsfc.driver import compile_form, compile_expression_dual_evaluation # noqa: F401
22
from tsfc.parameters import default_parameters # noqa: F401
3+
from tsfc.exceptions import MismatchingDomainError # noqa: F401
34

45

56
def register_citations():

tsfc/driver.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tsfc.parameters import default_parameters, is_complex
2424
from tsfc.ufl_utils import apply_mapping, extract_firedrake_constants
2525
import tsfc.kernel_interface.firedrake_loopy as firedrake_interface_loopy
26+
from tsfc.exceptions import MismatchingDomainError
27+
2628

2729
# To handle big forms. The various transformations might need a deeper stack
2830
sys.setrecursionlimit(3000)
@@ -147,7 +149,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
147149
integral_type=integral_data.integral_type,
148150
subdomain_id=integral_data.subdomain_id,
149151
domain_number=domain_number,
150-
domain_integral_type_map=integral_data.domain_integral_type_map,
152+
domain_integral_type_map={mesh: integral_data.domain_integral_type_map.get(mesh, None) for mesh in all_meshes},
151153
arguments=arguments,
152154
coefficients=coefficients,
153155
coefficient_split=coefficient_split,
@@ -188,19 +190,19 @@ def validate_domains(form):
188190
domain = itg.ufl_domain()
189191
for other_domain in itg.extra_domain_integral_type_map():
190192
if domain.submesh_youngest_common_ancester(other_domain) is None:
191-
raise ValueError("Assembly of forms over unrelated meshes is not supported. "
192-
"Try using Submeshes or cross-mesh interpolation.")
193+
raise MismatchingDomainError("Assembly of forms over unrelated meshes is not supported. "
194+
"Try using Submeshes or cross-mesh interpolation.")
193195

194196
# Check that all Arguments and Coefficients are defined on the valid domains
195197
valid_domains = set(itg.extra_domain_integral_type_map())
196198
valid_domains.add(domain)
197199

198200
itg_domains = set(extract_domains(itg))
199201
if len(itg_domains - valid_domains) > 0:
200-
raise ValueError("Argument or Coefficient domain not found in integral. "
201-
"Possibly, the form contains coefficients on different meshes "
202-
"and requires measure intersection, for example: "
203-
'Measure("dx", argument_mesh, intersect_measures=[Measure("dx", coefficient_mesh)]).')
202+
raise MismatchingDomainError("Argument or Coefficient domain not found in integral. "
203+
"Possibly, the form contains coefficients on different meshes "
204+
"and requires measure intersection, for example: "
205+
'Measure("dx", argument_mesh, intersect_measures=[Measure("dx", coefficient_mesh)]).')
204206

205207

206208
def preprocess_parameters(parameters):

tsfc/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
3+
class MismatchingDomainError(Exception):
4+
"""Error raised for unsupported multidomain problems"""

0 commit comments

Comments
 (0)