Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@
CodeGenTestSettings: Final[dict[str, dict[str, list[str]]]] = {
"internal": {"extras": ["jax"], "markers": ["not requires_dace"]}
}
CodeGenDaceTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "markers": ["requires_dace"]},
# Use dace-cartesian group to select the appropriate dace version
CodeGenCartesianTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "groups": ["dace-cartesian"], "markers": ["requires_dace"]},
}
# Install dace-next group to select the appropriate dace version
CodeGenNextTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "groups": ["dace-next"], "markers": ["requires_dace"]},
}


Expand Down Expand Up @@ -162,7 +167,7 @@ def test_cartesian(
) -> None:
"""Run selected 'gt4py.cartesian' tests."""

codegen_settings = CodeGenDaceTestSettings[codegen]
codegen_settings = CodeGenCartesianTestSettings[codegen]
device_settings = DeviceTestSettings[device]
extras = [
"standard",
Expand Down Expand Up @@ -245,7 +250,7 @@ def test_next(
) -> None:
"""Run selected 'gt4py.next' tests."""

codegen_settings = CodeGenDaceTestSettings[codegen]
codegen_settings = CodeGenNextTestSettings[codegen]
device_settings = DeviceTestSettings[device]
extras = [
"standard",
Expand Down
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ requires = ['cython>=3.0.0', 'setuptools>=77.0.3', 'versioningit>=3.1.1', 'wheel
# -- Dependency groups --
[dependency-groups]
build = ['cython>=3.0.0', 'pip>=22.1.1', 'setuptools>=77.0.3', 'wheel>=0.33.6']
dace-cartesian = [
'dace>=1.0.2' # refined in [tool.uv.sources]
]
dace-next = [
'dace==43!2026.04.27' # uses custom index at 'https://github.com/GridTools/pypi'
]
dev = [
{include-group = 'build'},
{include-group = 'docs'},
Expand Down Expand Up @@ -100,7 +106,6 @@ dependencies = [
'click>=8.0.0',
'cmake>=3.22',
'cytoolz>=1.0.1',
'dace>=2.0.0a3',
'deepdiff>=8.1.0',
'devtools>=0.6',
'factory-boy>=3.3.3',
Expand Down Expand Up @@ -457,6 +462,10 @@ conflicts = [
{extra = 'jax-cuda13'},
{extra = 'rocm6'},
{extra = 'rocm7'}
],
[
{group = 'dace-cartesian'},
{group = 'dace-next'}
]
]
default-groups = ["dev"]
Expand All @@ -473,9 +482,12 @@ name = 'gridtools'
url = 'https://gridtools.github.io/pypi/'

# Add the uv source below to pull dace from the gridtools index instead of PyPI:
# dace = {index = "gridtools"}
[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/romanc/dace", branch = "romanc/math-functions", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
]

# -- versioningit --
[tool.versioningit]
Expand Down
40 changes: 33 additions & 7 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _sdfg_add_arrays_and_edges(
inputs: set[str] | dict[str, dtypes.typeclass],
outputs: set[str] | dict[str, dtypes.typeclass],
origins: dict[str, tuple[int, ...]],
domain: tuple[int, ...],
) -> None:
for name, array in inner_sdfg.arrays.items():
if array.transient:
Expand Down Expand Up @@ -129,12 +130,20 @@ def _sdfg_add_arrays_and_edges(
if axis not in axes:
continue
o = origin[index]
e = field_info[name].boundary.lower_indices[cartesian_index]
lower, upper = field_info[name].boundary[cartesian_index]
s = inner_sdfg.arrays[name].shape[index]
ranges.append(
# s - 1 because ranges are inclusive
(o - max(0, e), o - max(0, e) + s - 1, 1)
)
if axis == CartesianSpace.Axis.K.name:
d = domain[cartesian_index]
ranges.append(
# max(0, lower) because ...
# d - 1 because ranges are inclusive
(o - max(0, lower), o + upper + d - 1, 1)
)
else:
ranges.append(
# s - 1 because ranges are inclusive
(o - max(0, lower), o - max(0, lower) + s - 1, 1)
)
index += 1

# Add data dimensions to the range
Expand Down Expand Up @@ -264,7 +273,7 @@ def freeze_origin_domain_sdfg(
nsdfg = state.add_nested_sdfg(inner_sdfg, inputs, outputs)

_sdfg_add_arrays_and_edges(
field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin
field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin, domain
)

# in special case of empty domain, remove entire SDFG.
Expand Down Expand Up @@ -920,7 +929,7 @@ def generate_extension(self) -> None:

@register
class DaceGPUBackend(BaseDaceBackend):
"""DaCe python backend using gt4py.cartesian.gtc."""
"""GPU DaCe python with an optimal KJI loop layout"""

name = "dace:gpu"
languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
Expand All @@ -933,3 +942,20 @@ class DaceGPUBackend(BaseDaceBackend):

def generate_extension(self) -> None:
return self.make_extension(uses_cuda=True)


@register
class DaceGPUBackendIJK(BaseDaceBackend):
"""GPU DaCe python with an optimal IJK loop layout"""

name = "dace:gpu_IJK"
languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
storage_info: ClassVar[layout.LayoutInfo] = layout_registry.from_name(name)
MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
options: ClassVar[GTBackendOptions] = {
**BaseGTBackend.GT_BACKEND_OPTS,
"device_sync": {"versioning": True, "type": bool},
}

def generate_extension(self) -> None:
return self.make_extension(uses_cuda=True)
28 changes: 14 additions & 14 deletions src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,38 +226,38 @@ def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> s
common.NativeFunction.ABS: "abs",
common.NativeFunction.MIN: "min",
common.NativeFunction.MAX: "max",
common.NativeFunction.MOD: "fmod",
common.NativeFunction.MOD: "dace.math.fmod",
common.NativeFunction.SIN: "dace.math.sin",
common.NativeFunction.COS: "dace.math.cos",
common.NativeFunction.TAN: "dace.math.tan",
common.NativeFunction.ARCSIN: "asin",
common.NativeFunction.ARCCOS: "acos",
common.NativeFunction.ARCTAN: "atan",
common.NativeFunction.ARCSIN: "dace.math.asin",
common.NativeFunction.ARCCOS: "dace.math.acos",
common.NativeFunction.ARCTAN: "dace.math.atan",
common.NativeFunction.SINH: "dace.math.sinh",
common.NativeFunction.COSH: "dace.math.cosh",
common.NativeFunction.TANH: "dace.math.tanh",
common.NativeFunction.ARCSINH: "asinh",
common.NativeFunction.ARCCOSH: "acosh",
common.NativeFunction.ARCTANH: "atanh",
common.NativeFunction.ARCSINH: "dace.math.asinh",
common.NativeFunction.ARCCOSH: "dace.math.acosh",
common.NativeFunction.ARCTANH: "dace.math.atanh",
common.NativeFunction.SQRT: "dace.math.sqrt",
common.NativeFunction.POW: "dace.math.pow",
common.NativeFunction.EXP: "dace.math.exp",
common.NativeFunction.LOG: "dace.math.log",
common.NativeFunction.LOG10: "log10",
common.NativeFunction.GAMMA: "tgamma",
common.NativeFunction.CBRT: "cbrt",
common.NativeFunction.LOG10: "dace.math.log10",
common.NativeFunction.GAMMA: "dace.math.tgamma",
common.NativeFunction.CBRT: "dace.math.cbrt",
common.NativeFunction.ISFINITE: "isfinite",
common.NativeFunction.ISINF: "isinf",
common.NativeFunction.ISNAN: "isnan",
common.NativeFunction.FLOOR: "dace.math.ifloor",
common.NativeFunction.CEIL: "ceil",
common.NativeFunction.TRUNC: "trunc",
common.NativeFunction.CEIL: "dace.math.ceil",
common.NativeFunction.TRUNC: "dace.math.trunc",
common.NativeFunction.INT32: "dace.int32",
common.NativeFunction.INT64: "dace.int64",
common.NativeFunction.FLOAT32: "dace.float32",
common.NativeFunction.FLOAT64: "dace.float64",
common.NativeFunction.ERF: "erf",
common.NativeFunction.ERFC: "erfc",
common.NativeFunction.ERF: "dace.math.erf",
common.NativeFunction.ERFC: "dace.math.erfc",
common.NativeFunction.ROUND: "nearbyint",
common.NativeFunction.ROUND_AWAY_FROM_ZERO: "round",
}
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/cartesian/gtc/dace/oir_to_treeir.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot:
param,
field_without_mask_extents[param.name],
k_bound,
symbols,
),
strides=get_dace_strides(param, symbols),
storage=DEFAULT_STORAGE_TYPE[self._device_type],
Expand All @@ -374,7 +373,7 @@ def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot:
# than persistent will yield issues with memory leaks.
containers[field.name] = data.Array(
dtype=utils.data_type_to_dace_typeclass(field.dtype),
shape=get_dace_shape(field, field_extent, k_bound, symbols),
shape=get_dace_shape(field, field_extent, k_bound),
strides=get_dace_strides(field, symbols),
transient=True,
lifetime=dtypes.AllocationLifetime.Persistent,
Expand Down Expand Up @@ -532,7 +531,6 @@ def get_dace_shape(
field: oir.FieldDecl,
extent: definitions.Extent,
k_bound: tuple[int, int],
symbols: tir.SymbolDict,
) -> list[symbolic.symbol]:
shape = []
for index, axis in enumerate(tir.Axis.dims_3d()):
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/storage/cartesian/layout_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def register(name: str, info: LayoutInfo) -> None:
is_optimal_layout=layout_checker_factory(layout_maker_factory((2, 1, 0))),
),
)
register(
"dace:gpu_IJK",
LayoutInfo(
alignment=32,
device="gpu",
layout_map=layout_maker_factory((0, 1, 2)),
is_optimal_layout=layout_checker_factory(layout_maker_factory((0, 1, 2))),
),
)
register(
"debug",
LayoutInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Literal

import pytest

from gt4py import cartesian as gt4pyc, storage as gt_storage
Expand All @@ -28,8 +30,8 @@ def is_rocm_cupy():

@pytest.mark.parametrize("backend", ALL_BACKENDS)
@pytest.mark.parametrize("order", ["C", "F"])
def test_numpy_allocators(backend, order):
if backend in ["gt:gpu", "dace:gpu"] and is_rocm_cupy():
def test_numpy_allocators(backend: str, order: Literal["C"] | Literal["F"]) -> None:
if is_rocm_cupy() and backend in ["gt:gpu", "dace:gpu", "dace:gpu_IJK"]:
pytest.skip(
f"This test would need GT4Py's custom `__hip_array_interface__` on the cupy array. Using dlpack (via nanobind) would make this test work for ROCm."
)
Expand All @@ -46,8 +48,8 @@ def test_numpy_allocators(backend, order):


@pytest.mark.parametrize("backend", PERFORMANCE_BACKENDS)
def test_bad_layout_warns(backend):
if backend in ["gt:gpu", "dace:gpu"] and is_rocm_cupy():
def test_bad_layout_warns(backend: str) -> None:
if is_rocm_cupy() and backend in ["gt:gpu", "dace:gpu", "dace:gpu_IJK"]:
pytest.skip(
f"This test would need GT4Py's custom `__hip_array_interface__` on the cupy array. Using dlpack (via nanobind) would make this test work for ROCm."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,7 @@ def test(out: Field[np.float64], inp: Field[np.float64]): # type: ignore


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_K_offset_write(backend):
if backend in ["gt:gpu", "dace:gpu"]:
pytest.skip(
f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1684"
)

def test_K_offset_write_simple(backend: str) -> None:
arraylib = get_array_library(backend)
array_shape = (1, 1, 4)
K_values = arraylib.arange(start=40, stop=44)
Expand All @@ -960,10 +955,19 @@ def simple(A: Field[np.float64], B: Field[np.float64]): # type: ignore
B = gt_storage.zeros(
backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64
)

simple(A, B)

assert (B[:, :, 0] == 0).all()
assert (B[:, :, 1:3] == K_values[0:2]).all()


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_K_offset_write_forward(backend: str) -> None:
arraylib = get_array_library(backend)
array_shape = (1, 1, 4)
K_values = arraylib.arange(start=40, stop=44)

# Order of operations: FORWARD with negative offset
# means while A is update B will have non-updated values of A
# Because of the interval, value of B[0] is 0
Expand All @@ -980,16 +984,29 @@ def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): #
B = gt_storage.zeros(
backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64
)

forward(A, B, 2.0)

assert (A[:, :, :3] == 2.0).all()
assert (A[:, :, 3] == K_values[3]).all()
assert (B[:, :, 0] == 0).all()
assert (B[:, :, 1:] == K_values[1:]).all()


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_K_offset_write_backward(backend: str) -> None:
arraylib = get_array_library(backend)
array_shape = (1, 1, 4)
K_values = arraylib.arange(start=40, stop=44)

# Order of operations: BACKWARD with negative offset
# means A is update B will get the updated values of A
# Because of the interval, B[0] is never written
@gtscript.stencil(backend=backend)
def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore
with computation(BACKWARD), interval(-1, None):
A = scalar

with computation(BACKWARD), interval(1, None):
A[0, 0, -1] = scalar
B[0, 0, 0] = A
Expand All @@ -998,13 +1015,15 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): #
backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64
)
A[:, :, :] = K_values
B = gt_storage.empty(
B = gt_storage.zeros(
backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64
)

backward(A, B, 2.0)
assert (A[:, :, :3] == 2.0).all()
assert (A[:, :, 3] == K_values[3]).all()
assert (B[:, :, :] == A[:, :, :]).all()

assert (A[:, :, :] == 2.0).all()
assert (B[:, :, 0] == 0.0).all()
assert (B[:, :, 1:] == 2.0).all()


@pytest.mark.parametrize("backend", ALL_BACKENDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,10 +800,7 @@ class TestVariableKAndReadOutside(gt_testing.StencilTestSuite):

def definition(field_in, field_out, index):
with computation(PARALLEL), interval(1, None):
field_out[0, 0, 0] = (
field_in[0, 0, index] # noqa: F841 [unused-variable]
+ field_in[0, 0, -2]
)
field_out[0, 0, 0] = field_in[0, 0, index] + field_in[0, 0, -2]

def validation(field_in, field_out, index, *, domain, origin):
idx = 1 + (np.arange(domain[-1]) + index)[1:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,22 @@ def test_integer_power_of_integer() -> None:
tasklet_code = visitor.visit_NativeFuncCall(pow_call, ctx=fake_context, is_target=False)

assert "ipow" not in tasklet_code


@pytest.mark.parametrize(
"arg",
[
oir.Literal(value="2", dtype=common.DataType.FLOAT32),
oir.Literal(value="2", dtype=common.DataType.FLOAT64),
],
)
def test_log10_respects_floating_point_precision(arg: oir.Literal) -> None:
log10_call = oir.NativeFuncCall(func=common.NativeFunction.LOG10, args=[arg])

visitor = oir_to_tasklet.OIRToTasklet()
fake_context = oir_to_tasklet.Context(
code="asdf", targets=set(), inputs={}, outputs={}, tree=None, scope=None
)
tasklet_code = visitor.visit_NativeFuncCall(log10_call, ctx=fake_context, is_target=False)

assert "dace.math.log10" in tasklet_code
Loading